├── .gitignore
├── INSTALL.md
├── LICENSE
├── README.md
├── data
│   └── list
│       └── cityscapes
│           ├── test.lst
│           ├── train.lst
│           ├── trainval.lst
│           └── val.lst
├── experiments
│   └── cityscapes
│       ├── w18.yaml
│       └── w48.yaml
├── lib
│   ├── config
│   │   ├── __init__.py
│   │   ├── default.py
│   │   └── models.py
│   ├── core
│   │   ├── criterion.py
│   │   └── function.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── base_dataset.py
│   │   └── cityscapes.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── conv_mask.py
│   │   ├── model_anytime.py
│   │   └── sync_bn
│   │       ├── __init__.py
│   │       └── inplace_abn
│   │           ├── __init__.py
│   │           ├── bn.py
│   │           ├── functions.py
│   │           └── src
│   │               ├── common.h
│   │               ├── inplace_abn.cpp
│   │               ├── inplace_abn.h
│   │               ├── inplace_abn_cpu.cpp
│   │               └── inplace_abn_cuda.cu
│   └── utils
│       ├── __init__.py
│       ├── metric.py
│       ├── modelsummary.py
│       └── utils.py
├── requirements.txt
└── tools
    ├── _init_paths.py
    ├── test_ee.py
    └── train_ee.py
/.gitignore:
--------------------------------------------------------------------------------
1 | **/__pycache__/
2 | pretrained_models/
3 | scripts/
4 | output_new/
--------------------------------------------------------------------------------
/INSTALL.md:
--------------------------------------------------------------------------------
1 | # Installation
2 | We provide installation instructions for Cityscapes segmentation experiments here.
3 | ## Dependency Setup
4 | Create a new conda virtual environment
5 | ```
6 | conda create -n anytime python=3.8 -y
7 | conda activate anytime
8 | ```
9 | Install PyTorch 1.1.0
10 | ```
11 | pip install torch==1.1.0
12 | ```
13 | Clone this repo and install required packages
14 | ```
15 | pip install -r requirements.txt
16 | ```
17 |
18 | ## Data preparation
19 | Download the [Cityscapes](https://www.cityscapes-dataset.com/) dataset and create a symbolic link to it under the `data` folder (`$DATA_ROOT` below is the path to your Cityscapes directory).
20 |
21 | ```
22 | mkdir data
23 | ln -s $DATA_ROOT data
24 | ```
25 |
26 | Structure the data as follows
27 | ````
28 | $ROOT/data
29 | └── cityscapes
30 | ├── gtFine
31 | │ ├── test
32 | │ ├── train
33 | │ └── val
34 | └── leftImg8bit
35 | ├── test
36 | ├── train
37 | └── val
38 |
39 | ````
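
Each line of the `list/cityscapes/*.lst` files pairs an image path with its label path, relative to `data/cityscapes` (`test.lst` lists images only); an illustrative entry:

```
leftImg8bit/train/aachen/aachen_000000_000019_leftImg8bit.png gtFine/train/aachen/aachen_000000_000019_gtFine_labelIds.png
```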
40 |
41 | ## Pretrained model preparation
42 | Create a folder named `pretrained_models` under the root directory.
43 | ```
44 | mkdir pretrained_models
45 | ```
46 | Download the [HRNet-W18-C-Small-v2](https://1drv.ms/u/s!Aus8VCZ_C_33gRmfdPR79WBS61Qn?e=HVZUi8) and [HRNet-W48-C](https://1drv.ms/u/s!Aus8VCZ_C_33dKvqI6pBZlifgJk) ImageNet-pretrained models from [HRNet-Image-Classification](https://github.com/HRNet/HRNet-Image-Classification.git)
47 | and structure the directory as follows
48 | ```
49 | pretrained_models
50 | ├── hrnet_w18_small_model_v2.pth
51 | └── hrnetv2_w48_imagenet_pretrained.pth
52 | ```
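
As a quick sanity check that the downloads are intact (a minimal sketch, assuming the checkpoints are plain state dicts as in HRNet-Image-Classification):

```python
import torch

# Load on CPU and report the number of weight tensors in the checkpoint.
state = torch.load('pretrained_models/hrnet_w18_small_model_v2.pth', map_location='cpu')
print(len(state))
```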
53 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Zhuang Liu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [Anytime Dense Prediction with Confidence Adaptivity](https://arxiv.org/abs/2104.00749)
2 |
3 | Official PyTorch implementation for the following paper:
4 |
5 | [Anytime Dense Prediction with Confidence Adaptivity](https://openreview.net/forum?id=kNKFOXleuC). ICLR 2022.\
6 | [Zhuang Liu](https://liuzhuang13.github.io), [Zhiqiu Xu](https://www.linkedin.com/in/oscar-xu-1250821a1/), [Hung-ju Wang](https://www.linkedin.com/in/hungju-wang-5a5124172/), [Trevor Darrell](https://people.eecs.berkeley.edu/~trevor/) and [Evan Shelhamer](http://imaginarynumber.net/)\
7 | UC Berkeley, Adobe Research
8 |
9 | Our implementation is based upon [HRNet-Semantic-Segmentation](https://github.com/HRNet/HRNet-Semantic-Segmentation/tree/pytorch-v1.1).
10 |
11 | ---
12 |
13 |
15 |
16 |
17 | Our full method, named **Anytime Dense Prediction with Confidence (ADP-C)**, achieves the same final accuracy as HRNet-W48 while significantly reducing total computation.
18 |
19 | ### Main Results
20 |
21 |
22 | | Setting (HRNet-W48) | model | exit1 mIoU | exit2 mIoU | exit3 mIoU | exit4 mIoU | mean mIoU | exit1 GFLOPs | exit2 GFLOPs | exit3 GFLOPs | exit4 GFLOPs | mean GFLOPs |
23 | | ------------------------- | :---: | :---: | :---: | :---: | :------: | :---------: | :---: | :---: | :---: | :-------: | :---------: |
24 | | HRNet-W48 | - | - | - | - | 80.7 | - | - | - | - | 696.2 | - |
25 | | EE | [model](https://drive.google.com/file/d/1GOXuP0e-qDp1mqiilxdhL8FyP8E-6pZP/view?usp=sharing) | 34.3 | 59.0 | 76.9 | 80.4 | 62.7 | 521.6 | 717.9 | 914.2 | 1110.5 | 816.0 |
26 | | EE + RH | [model](https://drive.google.com/file/d/11QNuEpq-oBErMKO3eMEddAU8ug9OYyts/view?usp=sharing) | 44.6 | 60.2 | 76.6 | 79.9 | 65.3 | 41.9 | 105.6 | 368.0 | 701.3 | 304.2 |
27 | | ADP-C: EE + RH + CA | [model](https://drive.google.com/file/d/1zcKkKWuknrLHpOEVRvUm82xlowjafQ_u/view?usp=sharing) | 44.3 | 60.1 | 76.8 | **81.3** | **65.7** | 41.9 | 93.9 | 259.3 | **387.1** | **195.6** |
28 |
29 |
30 |
31 | ## Installation
32 | Please check [INSTALL.md](INSTALL.md) for installation instructions.
33 |
34 | ## Evaluation on pretrained models
35 |
36 | Download our pretrained models from the table above and pass the downloaded file's location as `TEST.MODEL_FILE`.
37 |
38 | **Early Exits (EE)**
39 | ```bash
40 | python tools/test_ee.py --cfg experiments/cityscapes/w48.yaml \
41 |     TEST.MODEL_FILE <path-to-downloaded-model>.pth
42 | ```
43 | This should give
44 | ```
45 | 34.33 59.01 76.90 80.43 62.67
46 | ```
47 |
48 | **Redesigned Heads (RH)**
49 | ```bash
50 | python tools/test_ee.py --cfg experiments/cityscapes/w48.yaml \
51 | EXIT.TYPE 'flex' EXIT.INTER_CHANNEL 128 \
52 |     TEST.MODEL_FILE <path-to-downloaded-model>.pth
53 | ```
54 |
55 | This should give
56 | ```
57 | 44.61 60.19 76.64 79.89 65.33
58 | ```
59 |
60 | **ADP-C (EE + RH + CA)**
61 | ```bash
62 | python tools/test_ee.py \
63 | --cfg experiments/cityscapes/w48.yaml MODEL.NAME model_anytime \
64 | EXIT.TYPE 'flex' EXIT.INTER_CHANNEL 128 \
65 | MASK.USE True MASK.CONF_THRE 0.998 \
66 |     TEST.MODEL_FILE <path-to-downloaded-model>.pth
67 | ```
68 |
69 | This should give
70 | ```
71 | 44.34 60.13 76.82 81.31 65.65
72 | ```
73 |
74 | **ADP-C (EE + RH + CA)** with the W18 backbone ([pretrained W18 model with ADP-C](https://drive.google.com/file/d/1bU7spRV236OV7D5dgZzHy_AI0FN3oGAp/view?usp=sharing))
75 | ```bash
76 | python tools/test_ee.py \
77 | --cfg experiments/cityscapes/w18.yaml MODEL.NAME model_anytime \
78 | EXIT.TYPE 'flex' EXIT.INTER_CHANNEL 64 \
79 | MASK.USE True MASK.CONF_THRE 0.998 \
80 |     TEST.MODEL_FILE <path-to-downloaded-model>.pth
81 | ```
82 |
83 | This should give
84 | ```
85 | 40.83 48.19 68.26 77.02 58.57
86 | ```
87 |
88 |
89 | ## Train
90 |
91 | There are two configurations for the backbone HRNet model, `w48.yaml` and `w18.yaml`, under `experiments/cityscapes`. The following commands use `HRNet-W48` as the backbone; change `EXIT.INTER_CHANNEL` to `64` when using the `w18` config.
92 |
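All the `KEY VALUE` pairs appended to the commands in this README are yacs overrides, merged on top of the chosen YAML file by `update_config` in `lib/config/default.py`. A minimal sketch of that mechanism, using an illustrative subset of the real config tree:

```python
from yacs.config import CfgNode as CN

# Illustrative subset of the config tree defined in lib/config/default.py.
cfg = CN(new_allowed=True)
cfg.EXIT = CN(new_allowed=True)
cfg.EXIT.TYPE = 'original'
cfg.EXIT.INTER_CHANNEL = 64

# merge_from_list consumes alternating KEY, VALUE tokens, which is what
# `EXIT.TYPE 'flex' EXIT.INTER_CHANNEL 128` on the command line turns into.
cfg.merge_from_list(['EXIT.TYPE', 'flex', 'EXIT.INTER_CHANNEL', '128'])
assert cfg.EXIT.TYPE == 'flex' and cfg.EXIT.INTER_CHANNEL == 128
```
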
93 | **Early Exits (EE)**
94 |
95 | ````bash
96 | python -m torch.distributed.launch --nproc_per_node=4 tools/train_ee.py \
97 | --cfg experiments/cityscapes/w48.yaml
98 | ````
99 |
100 |
101 | **Redesigned Heads (RH)**
102 |
103 | ````bash
104 | python -m torch.distributed.launch --nproc_per_node=4 tools/train_ee.py \
105 | --cfg experiments/cityscapes/w48.yaml \
106 | EXIT.TYPE 'flex' EXIT.INTER_CHANNEL 128
107 | ````
108 |
109 |
110 | **Confidence Adaptivity (CA)**
111 |
112 | ````bash
113 | python -m torch.distributed.launch --nproc_per_node=4 tools/train_ee.py \
114 | --cfg experiments/cityscapes/w48.yaml \
115 | MASK.USE True MASK.CONF_THRE 0.998
116 | ````
117 |
118 |
119 | **ADP-C (EE + RH + CA)**
120 |
121 | ````bash
122 | python -m torch.distributed.launch --nproc_per_node=4 tools/train_ee.py \
123 | --cfg experiments/cityscapes/w48.yaml \
124 | EXIT.TYPE 'flex' EXIT.INTER_CHANNEL 128 \
125 | MASK.USE True MASK.CONF_THRE 0.998
126 | ````
127 |
128 | Evaluation results will be generated at the end of training; a short sketch for inspecting them follows the list below.
129 |
130 | - `result.txt`: contains the mIoU of each exit and the average mIoU over the four exits.
131 |
132 | - `test_stats.json`: contains FLOPs and number of parameters.
133 |
134 | - `final_state.pth`: the trained model file.
135 |
136 | - `config.yaml`: the configuration file.
137 |
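For a quick look at these artifacts (a minimal sketch; the output path is a placeholder that depends on `OUTPUT_DIR`, the dataset, and the config name):

```python
import json
import os

out_dir = 'output/...'  # placeholder: set to the run's output directory

print(open(os.path.join(out_dir, 'result.txt')).read())  # per-exit mIoU and their mean
with open(os.path.join(out_dir, 'test_stats.json')) as f:
    print(json.load(f))  # FLOPs and parameter counts
```
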
138 | ## Test
139 |
140 | **Evaluation**
141 |
142 | ```
143 | python tools/test_ee.py --cfg <output-dir>/config.yaml
144 | ```
145 |
146 | ## Acknowledgement
147 | This repository is built upon [HRNet-Semantic-Segmentation](https://github.com/HRNet/HRNet-Semantic-Segmentation/tree/pytorch-v1.1).
148 |
149 | ## License
150 | This project is released under the MIT license. Please see the [LICENSE](LICENSE) file for more information.
151 |
152 | ## Citation
153 | If you find this repository helpful, please consider citing:
154 | ```
155 | @Article{liu2022anytime,
156 | author = {Zhuang Liu and Zhiqiu Xu and Hung-Ju Wang and Trevor Darrell and Evan Shelhamer},
157 | title = {Anytime Dense Prediction with Confidence Adaptivity},
158 | journal = {International Conference on Learning Representations (ICLR)},
159 | year = {2022},
160 | }
161 | ```
162 |
--------------------------------------------------------------------------------
/experiments/cityscapes/w18.yaml:
--------------------------------------------------------------------------------
1 | CUDNN:
2 | BENCHMARK: true
3 | DETERMINISTIC: false
4 | ENABLED: true
5 | GPUS: (0,1,2,3)
6 | OUTPUT_DIR: 'output'
7 | LOG_DIR: 'log'
8 | WORKERS: 1
9 | PRINT_FREQ: 10
10 |
11 | MASK:
12 | USE: false
13 | INTERPOLATION: rbf
14 | P: 0.5
15 | CONF_THRE: 0.0
16 | ENTROPY_THRE: 0.0
17 | CRIT: conf_thre
18 | AGGR: copy
19 |
20 | EXIT:
21 | TYPE: original
22 | FINAL_CONV_KERNEL: 1
23 |
24 | DATASET:
25 | DATASET: cityscapes
26 | ROOT: 'data/'
27 | TEST_SET: 'list/cityscapes/val.lst'
28 | TRAIN_SET: 'list/cityscapes/train.lst'
29 | NUM_CLASSES: 19
30 | MODEL:
31 | NAME: 'model_anytime'
32 | PRETRAINED: 'pretrained_models/hrnet_w18_small_model_v2.pth'
33 | LOAD_STAGE: 0
34 | EXTRA:
35 | EE_WEIGHTS: (1,1,1,1)
36 | AGGREGATION: none
37 | EARLY_DETACH: false
38 | EXIT_NORM: BN
39 | STAGE1:
40 | NUM_MODULES: 1
41 | NUM_BRANCHES: 1
42 | BLOCK: BOTTLENECK
43 | NUM_BLOCKS:
44 | - 2
45 | NUM_CHANNELS:
46 | - 64
47 | FUSE_METHOD: SUM
48 | STAGE2:
49 | NUM_MODULES: 1
50 | NUM_BRANCHES: 2
51 | BLOCK: BASIC
52 | NUM_BLOCKS:
53 | - 2
54 | - 2
55 | NUM_CHANNELS:
56 | - 18
57 | - 36
58 | FUSE_METHOD: SUM
59 | STAGE3:
60 | NUM_MODULES: 3
61 | NUM_BRANCHES: 3
62 | BLOCK: BASIC
63 | NUM_BLOCKS:
64 | - 2
65 | - 2
66 | - 2
67 | NUM_CHANNELS:
68 | - 18
69 | - 36
70 | - 72
71 | FUSE_METHOD: SUM
72 | STAGE4:
73 | NUM_MODULES: 2
74 | NUM_BRANCHES: 4
75 | BLOCK: BASIC
76 | NUM_BLOCKS:
77 | - 2
78 | - 2
79 | - 2
80 | - 2
81 | NUM_CHANNELS:
82 | - 18
83 | - 36
84 | - 72
85 | - 144
86 | FUSE_METHOD: SUM
87 | LOSS:
88 | USE_OHEM: false
89 | OHEMTHRES: 0.9
90 | OHEMKEEP: 131072
91 | TRAIN:
92 | EE_ONLY: false
93 | ALLE_ONLY: false
94 | IMAGE_SIZE:
95 | - 1024
96 | - 512
97 | BASE_SIZE: 2048
98 | BATCH_SIZE_PER_GPU: 3
99 | SHUFFLE: true
100 | BEGIN_EPOCH: 0
101 | END_EPOCH: 484
102 | RESUME: false
103 | OPTIMIZER: sgd
104 | LR: 0.01
105 | WD: 0.0005
106 | MOMENTUM: 0.9
107 | NESTEROV: false
108 | FLIP: true
109 | MULTI_SCALE: true
110 | DOWNSAMPLERATE: 1
111 | IGNORE_LABEL: 255
112 | SCALE_FACTOR: 16
113 | TEST:
114 | SUB_DIR: ''
115 | IMAGE_SIZE:
116 | - 2048
117 | - 1024
118 | BASE_SIZE: 2048
119 | BATCH_SIZE_PER_GPU: 4
120 | CENTER_CROP_TEST: false
121 |
122 |
--------------------------------------------------------------------------------
/experiments/cityscapes/w48.yaml:
--------------------------------------------------------------------------------
1 | CUDNN:
2 | BENCHMARK: true
3 | DETERMINISTIC: false
4 | ENABLED: true
5 | GPUS: (0,1,2,3)
6 | OUTPUT_DIR: 'output'
7 | LOG_DIR: 'log'
8 | WORKERS: 1
9 | PRINT_FREQ: 10
10 |
11 | MASK:
12 | USE: false
13 | INTERPOLATION: rbf
14 | P: 0.5
15 | CONF_THRE: 0.0
16 | ENTROPY_THRE: 0.0
17 | CRIT: conf_thre
18 | AGGR: copy
19 |
20 | EXIT:
21 | TYPE: original
22 | FINAL_CONV_KERNEL: 1
23 |
24 | DATASET:
25 | DATASET: cityscapes
26 | ROOT: 'data/'
27 | TEST_SET: 'list/cityscapes/val.lst'
28 | TRAIN_SET: 'list/cityscapes/train.lst'
29 | NUM_CLASSES: 19
30 | MODEL:
31 | NAME: 'model_anytime'
32 | PRETRAINED: 'pretrained_models/hrnetv2_w48_imagenet_pretrained.pth'
33 | LOAD_STAGE: 0
34 | EXTRA:
35 | EE_WEIGHTS: (1,1,1,1)
36 | AGGREGATION: none
37 | EARLY_DETACH: false
38 | EXIT_NORM: BN
39 | STAGE1:
40 | NUM_MODULES: 1
41 | NUM_BRANCHES: 1
42 | BLOCK: BOTTLENECK
43 | NUM_BLOCKS:
44 | - 4
45 | NUM_CHANNELS:
46 | - 64
47 | FUSE_METHOD: SUM
48 | STAGE2:
49 | NUM_MODULES: 1
50 | NUM_BRANCHES: 2
51 | BLOCK: BASIC
52 | NUM_BLOCKS:
53 | - 4
54 | - 4
55 | NUM_CHANNELS:
56 | - 48
57 | - 96
58 | FUSE_METHOD: SUM
59 | STAGE3:
60 | NUM_MODULES: 4
61 | NUM_BRANCHES: 3
62 | BLOCK: BASIC
63 | NUM_BLOCKS:
64 | - 4
65 | - 4
66 | - 4
67 | NUM_CHANNELS:
68 | - 48
69 | - 96
70 | - 192
71 | FUSE_METHOD: SUM
72 | STAGE4:
73 | NUM_MODULES: 3
74 | NUM_BRANCHES: 4
75 | BLOCK: BASIC
76 | NUM_BLOCKS:
77 | - 4
78 | - 4
79 | - 4
80 | - 4
81 | NUM_CHANNELS:
82 | - 48
83 | - 96
84 | - 192
85 | - 384
86 | FUSE_METHOD: SUM
87 | LOSS:
88 | USE_OHEM: false
89 | OHEMTHRES: 0.9
90 | OHEMKEEP: 131072
91 | TRAIN:
92 | EE_ONLY: false
93 | ALLE_ONLY: false
94 | IMAGE_SIZE:
95 | - 1024
96 | - 512
97 | BASE_SIZE: 2048
98 | BATCH_SIZE_PER_GPU: 3
99 | SHUFFLE: true
100 | BEGIN_EPOCH: 0
101 | END_EPOCH: 484
102 | RESUME: false
103 | OPTIMIZER: sgd
104 | LR: 0.01
105 | WD: 0.0005
106 | MOMENTUM: 0.9
107 | NESTEROV: false
108 | FLIP: true
109 | MULTI_SCALE: true
110 | DOWNSAMPLERATE: 1
111 | IGNORE_LABEL: 255
112 | SCALE_FACTOR: 16
113 | TEST:
114 | SUB_DIR: ''
115 | IMAGE_SIZE:
116 | - 2048
117 | - 1024
118 | BASE_SIZE: 2048
119 | BATCH_SIZE_PER_GPU: 4
120 | CENTER_CROP_TEST: false
121 |
--------------------------------------------------------------------------------
/lib/config/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from .default import _C as config
6 | from .default import update_config
7 | from .models import MODEL_EXTRAS
8 |
9 |
10 |
--------------------------------------------------------------------------------
/lib/config/default.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import os
6 |
7 | from yacs.config import CfgNode as CN
8 |
9 |
10 | _C = CN(new_allowed=True)
11 |
12 | _C.OUTPUT_DIR = ''
13 | _C.LOG_DIR = ''
14 | _C.GPUS = (0,)
15 | _C.WORKERS = 4
16 | _C.PRINT_FREQ = 20
17 | _C.AUTO_RESUME = False
18 | _C.PIN_MEMORY = True
19 | _C.RANK = 0
20 |
21 | _C.CUDNN = CN()
22 | _C.CUDNN.BENCHMARK = True
23 | _C.CUDNN.DETERMINISTIC = False
24 | _C.CUDNN.ENABLED = True
25 |
26 | _C.MODEL = CN(new_allowed=True)
27 | _C.MODEL.NAME = 'seg_hrnet'
28 | _C.MODEL.PRETRAINED = ''
29 | _C.MODEL.LOAD_STAGE = 0
30 | _C.MODEL.EXTRA = CN(new_allowed=True)
31 |
32 | _C.LOSS = CN()
33 | _C.LOSS.USE_OHEM = False
34 | _C.LOSS.OHEMTHRES = 0.9
35 | _C.LOSS.OHEMKEEP = 100000
36 | _C.LOSS.CLASS_BALANCE = True
37 |
38 | _C.DATASET = CN()
39 | _C.DATASET.ROOT = ''
40 | _C.DATASET.DATASET = 'cityscapes'
41 | _C.DATASET.NUM_CLASSES = 19
42 | _C.DATASET.TRAIN_SET = 'list/cityscapes/train.lst'
43 | _C.DATASET.EXTRA_TRAIN_SET = ''
44 | _C.DATASET.TEST_SET = 'list/cityscapes/val.lst'
45 |
46 | _C.TRAIN = CN(new_allowed=True)
47 |
48 | _C.TRAIN.IMAGE_SIZE = [1024, 512]
49 | _C.TRAIN.BASE_SIZE = 2048
50 | _C.TRAIN.DOWNSAMPLERATE = 1
51 | _C.TRAIN.FLIP = True
52 | _C.TRAIN.MULTI_SCALE = True
53 | _C.TRAIN.SCALE_FACTOR = 16
54 |
55 | _C.TRAIN.LR_FACTOR = 0.1
56 | _C.TRAIN.LR_STEP = [90, 110]
57 | _C.TRAIN.LR = 0.01
58 | _C.TRAIN.EXTRA_LR = 0.001
59 |
60 | _C.TRAIN.OPTIMIZER = 'sgd'
61 | _C.TRAIN.MOMENTUM = 0.9
62 | _C.TRAIN.WD = 0.0001
63 | _C.TRAIN.NESTEROV = False
64 | _C.TRAIN.IGNORE_LABEL = -1
65 |
66 | _C.TRAIN.BEGIN_EPOCH = 0
67 | _C.TRAIN.END_EPOCH = 484
68 | _C.TRAIN.EXTRA_EPOCH = 0
69 |
70 | _C.TRAIN.RESUME = False
71 |
72 | _C.TRAIN.BATCH_SIZE_PER_GPU = 32
73 | _C.TRAIN.SHUFFLE = True
74 | _C.TRAIN.NUM_SAMPLES = 0
75 |
76 | _C.TEST = CN(new_allowed=True)
77 |
78 | _C.TEST.IMAGE_SIZE = [2048, 1024]
79 | _C.TEST.BASE_SIZE = 2048
80 |
81 | _C.TEST.BATCH_SIZE_PER_GPU = 32
82 | _C.TEST.NUM_SAMPLES = 0
83 |
84 | _C.TEST.MODEL_FILE = ''
85 | _C.TEST.FLIP_TEST = False
86 | _C.TEST.MULTI_SCALE = False
87 | _C.TEST.CENTER_CROP_TEST = False
88 | _C.TEST.SCALE_LIST = [1]
89 |
90 | _C.DEBUG = CN()
91 | _C.DEBUG.DEBUG = False
92 | _C.DEBUG.SAVE_BATCH_IMAGES_GT = False
93 | _C.DEBUG.SAVE_BATCH_IMAGES_PRED = False
94 | _C.DEBUG.SAVE_HEATMAPS_GT = False
95 | _C.DEBUG.SAVE_HEATMAPS_PRED = False
96 |
97 | _C.EXIT = CN(new_allowed=True)
98 |
99 | _C.EXIT.TYPE = 'original'
100 | _C.EXIT.FINAL_CONV_KERNEL = 1
101 | _C.EXIT.COMP_RATE = 1.0
102 | _C.EXIT.SMOOTH = False
103 | _C.EXIT.SMOOTH_KS = 3
104 | _C.EXIT.LAST_SAME = False
105 | _C.EXIT.FIX_INTER_CHANNEL = False
106 | _C.EXIT.INTER_CHANNEL = 64
107 |
108 | _C.MASK = CN(new_allowed=True)
109 | _C.MASK.ENTROPY_THRE = 0.0
110 |
111 | _C.PYRAMID_TEST = CN(new_allowed=True)
112 | _C.PYRAMID_TEST.USE = False
113 | _C.PYRAMID_TEST.SIZE = 512
114 |
115 |
116 |
117 | def update_config(cfg, args):
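    # Merge the YAML file given by args.cfg first, then apply the KEY VALUE
    # override pairs from args.opts (the trailing tokens of the CLI commands).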
118 | cfg.defrost()
119 |
120 | cfg.merge_from_file(args.cfg)
121 | cfg.merge_from_list(args.opts)
122 |
123 | cfg.freeze()
124 |
125 |
126 | if __name__ == '__main__':
127 | import sys
128 | with open(sys.argv[1], 'w') as f:
129 | print(_C, file=f)
130 |
131 |
--------------------------------------------------------------------------------
/lib/config/models.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from yacs.config import CfgNode as CN
6 |
7 | HIGH_RESOLUTION_NET = CN()
8 | HIGH_RESOLUTION_NET.PRETRAINED_LAYERS = ['*']
9 | HIGH_RESOLUTION_NET.STEM_INPLANES = 64
10 | HIGH_RESOLUTION_NET.FINAL_CONV_KERNEL = 1
11 | HIGH_RESOLUTION_NET.WITH_HEAD = True
12 |
13 | HIGH_RESOLUTION_NET.STAGE1 = CN()
14 | HIGH_RESOLUTION_NET.STAGE1.NUM_MODULES = 1
15 | HIGH_RESOLUTION_NET.STAGE1.NUM_BRANCHES = 1
16 | HIGH_RESOLUTION_NET.STAGE1.NUM_BLOCKS = [4]
17 | HIGH_RESOLUTION_NET.STAGE1.NUM_CHANNELS = [32]
18 | HIGH_RESOLUTION_NET.STAGE1.BLOCK = 'BASIC'
19 | HIGH_RESOLUTION_NET.STAGE1.FUSE_METHOD = 'SUM'
20 |
21 | HIGH_RESOLUTION_NET.STAGE2 = CN()
22 | HIGH_RESOLUTION_NET.STAGE2.NUM_MODULES = 1
23 | HIGH_RESOLUTION_NET.STAGE2.NUM_BRANCHES = 2
24 | HIGH_RESOLUTION_NET.STAGE2.NUM_BLOCKS = [4, 4]
25 | HIGH_RESOLUTION_NET.STAGE2.NUM_CHANNELS = [32, 64]
26 | HIGH_RESOLUTION_NET.STAGE2.BLOCK = 'BASIC'
27 | HIGH_RESOLUTION_NET.STAGE2.FUSE_METHOD = 'SUM'
28 |
29 | HIGH_RESOLUTION_NET.STAGE3 = CN()
30 | HIGH_RESOLUTION_NET.STAGE3.NUM_MODULES = 1
31 | HIGH_RESOLUTION_NET.STAGE3.NUM_BRANCHES = 3
32 | HIGH_RESOLUTION_NET.STAGE3.NUM_BLOCKS = [4, 4, 4]
33 | HIGH_RESOLUTION_NET.STAGE3.NUM_CHANNELS = [32, 64, 128]
34 | HIGH_RESOLUTION_NET.STAGE3.BLOCK = 'BASIC'
35 | HIGH_RESOLUTION_NET.STAGE3.FUSE_METHOD = 'SUM'
36 |
37 | HIGH_RESOLUTION_NET.STAGE4 = CN()
38 | HIGH_RESOLUTION_NET.STAGE4.NUM_MODULES = 1
39 | HIGH_RESOLUTION_NET.STAGE4.NUM_BRANCHES = 4
40 | HIGH_RESOLUTION_NET.STAGE4.NUM_BLOCKS = [4, 4, 4, 4]
41 | HIGH_RESOLUTION_NET.STAGE4.NUM_CHANNELS = [32, 64, 128, 256]
42 | HIGH_RESOLUTION_NET.STAGE4.BLOCK = 'BASIC'
43 | HIGH_RESOLUTION_NET.STAGE4.FUSE_METHOD = 'SUM'
44 |
45 | MODEL_EXTRAS = {
46 | 'seg_hrnet': HIGH_RESOLUTION_NET,
47 | }
48 |
--------------------------------------------------------------------------------
/lib/core/criterion.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import functional as F
4 |
5 | class CrossEntropy(nn.Module):
6 | def __init__(self, ignore_label=-1, weight=None):
7 | super(CrossEntropy, self).__init__()
8 | self.ignore_label = ignore_label
9 | self.criterion = nn.CrossEntropyLoss(weight=weight,
10 | ignore_index=ignore_label)
11 |
12 | def forward(self, score, target):
13 | ph, pw = score.size(2), score.size(3)
14 | h, w = target.size(1), target.size(2)
15 | if ph != h or pw != w:
16 | score = F.upsample(
17 | input=score, size=(h, w), mode='bilinear')
18 |
19 | loss = self.criterion(score, target)
20 |
21 | return loss
22 |
23 |
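# Usage sketch (hypothetical shapes): exit logits may come at reduced
# resolution; they are bilinearly upsampled to the label size before the loss.
if __name__ == '__main__':
    criterion = CrossEntropy(ignore_label=255)
    score = torch.randn(2, 19, 128, 256)            # N x C x h x w exit logits
    target = torch.randint(0, 19, (2, 512, 1024))   # N x H x W ground-truth labels
    print(criterion(score, target))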
--------------------------------------------------------------------------------
/lib/core/function.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import time
4 |
5 | import numpy as np
6 | import numpy.ma as ma
7 | from tqdm import tqdm
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.distributed as dist
12 | from torch.nn import functional as F
13 |
14 | from utils.utils import AverageMeter
15 | from utils.utils import get_confusion_matrix_gpu
16 | from utils.utils import adjust_learning_rate
17 | from utils.utils import get_world_size, get_rank
18 | from utils.modelsummary import get_model_summary
19 |
20 | import pdb
21 | from PIL import Image
22 | import cv2
23 | import time
24 |
25 | def reduce_tensor(inp):
26 | world_size = get_world_size()
27 | if world_size < 2:
28 | return inp
29 | with torch.no_grad():
30 | reduced_inp = inp
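        # dist.reduce sums across ranks onto rank 0; callers divide by
        # world_size when they want to report an average.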
31 | dist.reduce(reduced_inp, dst=0)
32 | return reduced_inp
33 |
34 |
35 | def train_ee(config, epoch, num_epoch, epoch_iters, base_lr, num_iters,
36 | trainloader, optimizer, model, writer_dict, device):
37 |
38 | model.train()
39 | torch.manual_seed(get_rank() + epoch * 123)
40 |
41 | if config.TRAIN.EE_ONLY or config.TRAIN.ALLE_ONLY:
42 | model.eval()
43 | model.module.model.exit1.train()
44 | model.module.model.exit2.train()
45 | model.module.model.exit3.train()
46 | if config.TRAIN.ALLE_ONLY:
47 | model.module.model.last_layer.train()
48 |
49 |
50 | data_time = AverageMeter()
51 | batch_time = AverageMeter()
52 | ave_loss = AverageMeter()
53 |
54 | tic_data = time.time()
55 | tic = time.time()
56 | tic_total = time.time()
57 | cur_iters = epoch*epoch_iters
58 | writer = writer_dict['writer']
59 | global_steps = writer_dict['train_global_steps']
60 | rank = get_rank()
61 | world_size = get_world_size()
62 |
63 | for i_iter, batch in enumerate(trainloader):
64 | data_time.update(time.time() - tic_data)
65 |
66 |
67 | images, labels, _, _ = batch
68 | images = images.to(device)
69 | labels = labels.long().to(device)
70 |
71 | losses, _ = model(images, labels)
72 |
73 | loss = 0
74 | reduced_losses = []
75 | for i, l in enumerate(losses):
76 | loss += config.MODEL.EXTRA.EE_WEIGHTS[i] * losses[i]
77 | reduced_losses.append(reduce_tensor(losses[i]))
78 | reduced_loss = reduce_tensor(loss)
79 |
80 | model.zero_grad()
81 | loss.backward()
82 | optimizer.step()
83 |
84 |
85 | ave_loss.update(reduced_loss.item())
86 |
87 | lr = adjust_learning_rate(optimizer,
88 | base_lr,
89 | num_iters,
90 | i_iter+cur_iters)
91 |
92 | batch_time.update(time.time() - tic)
93 | tic = time.time()
94 |
95 |
96 | if i_iter % config.PRINT_FREQ == 0 and rank == 0:
97 |
98 | print_loss = reduced_loss / world_size
99 | msg = 'Epoch: [{: >3d}/{}] Iter:[{: >3d}/{}], Time: {:.2f}, Data Time: {:.2f} ' \
100 | 'lr: {:.6f}, Loss: {:.6f}' .format(
101 | epoch, num_epoch, i_iter, epoch_iters,
102 | batch_time.average(), data_time.average(), lr, print_loss)
103 | logging.info(msg)
104 |
105 | global_steps = writer_dict['train_global_steps']
106 | writer.add_scalar('train_loss', print_loss, global_steps)
107 |
108 | writer.add_scalars('exit_train_loss', {
109 | 'exit1': reduced_losses[0].item() / world_size,
110 | 'exit2': reduced_losses[1].item() / world_size,
111 | 'exit3': reduced_losses[2].item() / world_size,
112 | 'exit4': reduced_losses[3].item() / world_size,
113 | },
114 | global_steps)
115 |
116 | writer_dict['train_global_steps'] += 1
117 |
118 | tic_data = time.time()
119 |
120 | train_time = time.time() - tic_total
121 |
122 | if rank == 0:
123 | logging.info(f'Train time:{train_time}s')
124 |
125 | def validate_ee(config, testloader, model, writer_dict, device):
126 |
127 | torch.manual_seed(get_rank())
128 |
129 | tic_data = time.time()
130 | tic = time.time()
131 | tic_total = time.time()
132 | rank = get_rank()
133 | world_size = get_world_size()
134 | model.eval()
135 |
136 | data_time = AverageMeter()
137 | batch_time = AverageMeter()
138 | ave_loss = AverageMeter()
139 |
140 | num_exits = len(config.MODEL.EXTRA.EE_WEIGHTS)
141 |
142 | ave_losses = [AverageMeter() for i in range(num_exits)]
143 |
144 | confusion_matrices = [np.zeros((config.DATASET.NUM_CLASSES, config.DATASET.NUM_CLASSES)) for i in range(num_exits)]
145 |
146 |
147 | with torch.no_grad():
148 | for i_iter, batch in enumerate(testloader):
149 | data_time.update(time.time() - tic_data)
150 |
151 | image, label, _, _ = batch
152 | size = label.size()
153 | image = image.to(device)
154 | label = label.long().to(device)
155 |
156 | losses, preds = model(image, label)
157 |
158 | for i, pred in enumerate(preds):
159 | if pred.size()[-2] != size[-2] or pred.size()[-1] != size[-1]:
160 | pred = F.upsample(pred, (size[-2], size[-1]),
161 | mode='bilinear')
162 |
163 | confusion_matrices[i] += get_confusion_matrix_gpu(
164 | label,
165 | pred,
166 | size,
167 | config.DATASET.NUM_CLASSES,
168 | config.TRAIN.IGNORE_LABEL)
169 |
170 | loss = 0
171 | reduced_losses = []
172 | for i, l in enumerate(losses):
173 | loss += config.MODEL.EXTRA.EE_WEIGHTS[i] * losses[i]
174 | reduced_losses.append(reduce_tensor(losses[i]))
175 | ave_losses[i].update(reduced_losses[i].item())
176 |
177 | reduced_loss = reduce_tensor(loss)
178 | ave_loss.update(reduced_loss.item())
179 |
180 | batch_time.update(time.time() - tic)
181 | tic = time.time()
182 |
183 | tic_data = time.time()
184 |
185 | if i_iter % config.PRINT_FREQ == 0 and rank == 0:
186 | print_loss = ave_loss.average() / world_size
187 | msg = 'Iter:[{: >3d}/{}], Time: {:.2f}, Data Time: {:.2f} ' \
188 | 'Loss: {:.6f}' .format(
189 | i_iter, len(testloader), batch_time.average(), data_time.average(), print_loss)
190 | logging.info(msg)
191 |
192 |
193 | results = []
194 | for i, confusion_matrix in enumerate(confusion_matrices):
195 |
196 | confusion_matrix = torch.from_numpy(confusion_matrix).to(device)
197 | reduced_confusion_matrix = reduce_tensor(confusion_matrix)
198 | confusion_matrix = reduced_confusion_matrix.cpu().numpy()
199 |
200 | pos = confusion_matrix.sum(1)
201 | res = confusion_matrix.sum(0)
202 | tp = np.diag(confusion_matrix)
203 | pixel_acc = tp.sum()/pos.sum()
204 | mean_acc = (tp/np.maximum(1.0, pos)).mean()
205 | IoU_array = (tp / np.maximum(1.0, pos + res - tp))
206 | mean_IoU = IoU_array.mean()
207 |
208 | results.append((mean_IoU, IoU_array, pixel_acc, mean_acc))
209 |
210 | val_time = time.time() - tic_total
211 |
212 | if rank == 0:
213 | logging.info(f'Validation time:{val_time}s')
214 | mean_IoUs = [result[0] for result in results]
215 | mean_IoUs.append(np.mean(mean_IoUs))
216 | print_result = '\t'.join(['{:.2f}'.format(m*100) for m in mean_IoUs])
217 | logging.info(f'mean_IoUs: {print_result}')
218 |
219 | writer = writer_dict['writer']
220 | global_steps = writer_dict['valid_global_steps']
221 | writer.add_scalar('valid_loss', print_loss, global_steps)
222 |
223 | writer.add_scalars('exit_valid_loss', {
224 | 'exit1': ave_losses[0].average() / world_size,
225 | 'exit2': ave_losses[1].average() / world_size,
226 | 'exit3': ave_losses[2].average() / world_size,
227 | 'exit4': ave_losses[3].average() / world_size,
228 | },
229 | global_steps)
230 |
231 | writer.add_scalars('valid_mIoUs',
232 | {f'valid_mIoU{i+1}': results[i][0] for i in range(num_exits)},
233 | global_steps
234 | )
235 | writer_dict['valid_global_steps'] += 1
236 |
237 | return results
238 |
239 |
240 | VIS_T = False
241 | VIS = False
242 | VIS_CONF = False
243 | TIMING = True
244 |
245 | def testval_ee(config, test_dataset, testloader, model,
246 | sv_dir='', sv_pred=False):
247 | model.eval()
248 | torch.manual_seed(get_rank())
249 | num_exits = len(config.MODEL.EXTRA.EE_WEIGHTS)
250 |
251 | confusion_matrices = [np.zeros((config.DATASET.NUM_CLASSES, config.DATASET.NUM_CLASSES)) for i in range(num_exits)]
252 |
253 | total_time = 0
254 |
255 | with torch.no_grad():
256 | for index, batch in enumerate(tqdm(testloader)):
257 | image, label, _, name = batch
258 | if config.PYRAMID_TEST.USE:
259 | image = F.interpolate(image, (config.PYRAMID_TEST.SIZE//2, config.PYRAMID_TEST.SIZE), mode='bilinear')
260 |
261 | size = label.size()
262 |
263 | if TIMING:
264 | torch.cuda.synchronize()
265 | start = time.time()
266 | preds = model(image)
267 |
268 | if TIMING:
269 | torch.cuda.synchronize()
270 | total_time += time.time() - start
271 |
272 | for i, pred in enumerate(preds):
273 | if pred.size()[-2] != size[-2] or pred.size()[-1] != size[-1]:
274 | original_logits = pred
275 | pred = F.upsample(pred, (size[-2], size[-1]),
276 | mode='bilinear')
277 |
278 | confusion_matrices[i] += get_confusion_matrix_gpu(
279 | label,
280 | pred,
281 | size,
282 | config.DATASET.NUM_CLASSES,
283 | config.TRAIN.IGNORE_LABEL)
284 |
285 | if sv_pred and index % 20 == 0 and VIS:
286 | print("Saving ... ", name)
287 | sv_path = os.path.join(sv_dir, f'test_val_results/{i+1}')
288 | os.makedirs(sv_path, exist_ok=True)
289 | test_dataset.save_pred(pred, sv_path, name)
290 |
291 | if VIS_T or VIS_CONF:
292 | def save_float_img(t, sv_path, name, normalize=False):
293 | os.makedirs(sv_path, exist_ok=True)
294 | if normalize:
295 | t = t/t.max()
296 | torch.save(t, os.path.join(sv_path, name[0]+'.pth'))
297 | t = t[0][0]
298 | t = t.cpu().numpy().copy()
299 | np.save(os.path.join(sv_path, name[0]+'.npy'), t)
300 | cv2.imwrite(os.path.join(sv_path, name[0]+'.png'), t*255)
301 |
302 | def save_long_img(t, sv_path, name):
303 | os.makedirs(sv_path, exist_ok=True)
304 | t = t[0][0]
305 | t = t.cpu().numpy().copy()
306 | cv2.imwrite(os.path.join(sv_path, name[0]+'.png'), t)
307 |
308 | def save_tensor(t, sv_path, name):
309 | os.makedirs(sv_path, exist_ok=True)
310 | torch.save(t, os.path.join(sv_path, name[0]+'.pth'))
311 |
312 |
313 | if VIS_CONF:
314 |
315 | out = F.softmax(original_logits, dim=1)
316 |
317 | sv_path = os.path.join(sv_dir, f'test_val_original_conf/{i+1}')
318 | original_conf_map, _ = out.max(dim=1)
319 | save_float_img(original_conf_map.unsqueeze(0), sv_path, name, normalize=False)
320 |
321 | sv_path = os.path.join(sv_dir, f'test_val_original_pred/{i+1}')
322 | max_index = torch.max(out, dim=1)[1]
323 | save_long_img(max_index.unsqueeze(0), sv_path, name)
324 |
325 | sv_path = os.path.join(sv_dir, f'test_val_original_logits/{i+1}')
326 | save_tensor(original_logits, sv_path, name)
327 |
328 | sv_path = os.path.join(sv_dir, f'test_val_original_results/{i+1}')
329 | os.makedirs(sv_path, exist_ok=True)
330 | test_dataset.save_pred(original_logits, sv_path, name)
331 |
332 | if hasattr(model.module, 'mask_dict'):
333 | sv_path = os.path.join(sv_dir, f'test_val_masks/')
334 | os.makedirs(sv_path, exist_ok=True)
335 | torch.save(model.module.mask_dict, os.path.join(sv_path, name[0]+'.pth'))
336 |
337 | if i == 0:
338 | sv_path = os.path.join(sv_dir, f'test_val_gt/')
339 | save_long_img(label.unsqueeze(0), sv_path, name)
340 | if index % 100 == 0:
341 | logging.info(f'processing: {index} images with exit {i}')
342 | pos = confusion_matrices[i].sum(1)
343 | res = confusion_matrices[i].sum(0)
344 | tp = np.diag(confusion_matrices[i])
345 | IoU_array = (tp / np.maximum(1.0, pos + res - tp))
346 | mean_IoU = IoU_array.mean()
347 | logging.info('mIoU: %.4f' % (mean_IoU))
348 |
349 | results = []
350 | for i, confusion_matrix in enumerate(confusion_matrices):
351 | pos = confusion_matrix.sum(1)
352 | res = confusion_matrix.sum(0)
353 | tp = np.diag(confusion_matrix)
354 | pixel_acc = tp.sum()/pos.sum()
355 | mean_acc = (tp/np.maximum(1.0, pos)).mean()
356 | IoU_array = (tp / np.maximum(1.0, pos + res - tp))
357 | mean_IoU = IoU_array.mean()
358 |
359 | results.append((mean_IoU, IoU_array, pixel_acc, mean_acc))
360 |
361 | if TIMING:
362 | print("Total_time", total_time)
363 |
364 | return results
365 |
366 |
367 | def testval_ee_class(config, test_dataset, testloader, model,
368 | sv_dir='', sv_pred=False):
369 | model.eval()
370 | torch.manual_seed(get_rank())
371 | num_exits = len(config.MODEL.EXTRA.EE_WEIGHTS)
372 |
373 | confusion_matrices = [np.zeros((config.DATASET.NUM_CLASSES, config.DATASET.NUM_CLASSES)) for i in range(num_exits)]
374 |
375 | total_time = 0
376 |
377 | with torch.no_grad():
378 | for index, batch in enumerate(tqdm(testloader)):
379 | image, label, _, name = batch
380 |
381 | size = label.size()
382 | preds = model(image)
383 |
384 | for i, pred in enumerate(preds):
385 | if pred.size()[-2] != size[-2] or pred.size()[-1] != size[-1]:
386 | original_logits = pred
387 | pred = F.upsample(pred, (size[-2], size[-1]),
388 | mode='bilinear')
389 |
390 | confusion_matrices[i] += get_confusion_matrix_gpu(
391 | label,
392 | pred,
393 | size,
394 | config.DATASET.NUM_CLASSES,
395 | config.TRAIN.IGNORE_LABEL)
396 |
397 | if sv_pred and index % 20 == 0 and VIS:
398 | print("Saving ... ", name)
399 | sv_path = os.path.join(sv_dir, f'test_val_results/{i+1}')
400 | os.makedirs(sv_path, exist_ok=True)
401 | test_dataset.save_pred(pred, sv_path, name)
402 |
403 | if VIS_T or VIS_CONF:
404 | def save_float_img(t, sv_path, name, normalize=False):
405 | os.makedirs(sv_path, exist_ok=True)
406 | if normalize:
407 | t = t/t.max()
408 | torch.save(t, os.path.join(sv_path, name[0]+'.pth'))
409 | t = t[0][0]
410 | t = t.cpu().numpy().copy()
411 | np.save(os.path.join(sv_path, name[0]+'.npy'), t)
412 | cv2.imwrite(os.path.join(sv_path, name[0]+'.png'), t*255)
413 |
414 | def save_long_img(t, sv_path, name):
415 | os.makedirs(sv_path, exist_ok=True)
416 | t = t[0][0]
417 | t = t.cpu().numpy().copy()
418 | cv2.imwrite(os.path.join(sv_path, name[0]+'.png'), t)
419 |
420 | def save_tensor(t, sv_path, name):
421 | os.makedirs(sv_path, exist_ok=True)
422 | torch.save(t, os.path.join(sv_path, name[0]+'.pth'))
423 | if VIS_CONF:
424 | out = F.softmax(original_logits, dim=1)
425 |
426 | sv_path = os.path.join(sv_dir, f'test_val_original_conf/{i+1}')
427 | original_conf_map, _ = out.max(dim=1)
428 | save_float_img(original_conf_map.unsqueeze(0), sv_path, name, normalize=False)
429 |
430 | sv_path = os.path.join(sv_dir, f'test_val_original_pred/{i+1}')
431 | max_index = torch.max(out, dim=1)[1]
432 | save_long_img(max_index.unsqueeze(0), sv_path, name)
433 |
434 | sv_path = os.path.join(sv_dir, f'test_val_original_logits/{i+1}')
435 | save_tensor(original_logits, sv_path, name)
436 |
437 | sv_path = os.path.join(sv_dir, f'test_val_original_results/{i+1}')
438 | os.makedirs(sv_path, exist_ok=True)
439 | test_dataset.save_pred(original_logits, sv_path, name)
440 |
441 | if hasattr(model.module, 'mask_dict'):
442 | sv_path = os.path.join(sv_dir, f'test_val_masks/')
443 | os.makedirs(sv_path, exist_ok=True)
444 | torch.save(model.module.mask_dict, os.path.join(sv_path, name[0]+'.pth'))
445 |
446 | if i == 0:
447 | sv_path = os.path.join(sv_dir, f'test_val_gt/')
448 | save_long_img(label.unsqueeze(0), sv_path, name)
449 |
450 | if index % 100 == 0:
451 | logging.info(f'processing: {index} images with exit {i}')
452 | pos = confusion_matrices[i].sum(1)
453 | res = confusion_matrices[i].sum(0)
454 | tp = np.diag(confusion_matrices[i])
455 | IoU_array = (tp / np.maximum(1.0, pos + res - tp))
456 | mean_IoU = IoU_array.mean()
457 | logging.info('mIoU: %.4f' % (mean_IoU))
458 |
459 | results = []
460 | for i, confusion_matrix in enumerate(confusion_matrices):
461 | pos = confusion_matrix.sum(1)
462 | res = confusion_matrix.sum(0)
463 | tp = np.diag(confusion_matrix)
464 | pixel_acc = tp.sum()/pos.sum()
465 | mean_acc = (tp/np.maximum(1.0, pos)).mean()
466 | IoU_array = (tp / np.maximum(1.0, pos + res - tp))
467 | mean_IoU = IoU_array.mean()
468 |
469 | results.append((mean_IoU, IoU_array, pixel_acc, mean_acc))
470 |
471 | if TIMING:
472 | print("Total_time", total_time)
473 |
474 | return results
475 | def testval_ee_profiling(config, test_dataset, testloader, model,
476 | sv_dir='', sv_pred=False):
477 | model.eval()
478 | torch.manual_seed(get_rank())
479 | num_exits = len(config.MODEL.EXTRA.EE_WEIGHTS)
480 | total_time = 0
481 |
482 | gflops = []
483 | with torch.no_grad():
484 | for index, batch in enumerate(tqdm(testloader)):
485 | image, label, _, name = batch
486 | if config.PYRAMID_TEST.USE:
487 | image = F.interpolate(image, (config.PYRAMID_TEST.SIZE//2, config.PYRAMID_TEST.SIZE), mode='bilinear')
488 | stats = {}
489 | saved_stats = {}
490 |
491 | for i in range(4):
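                # stop{i+1} is a sentinel attribute (its value is irrelevant): while it is
                # set, the anytime model halts after exit i+1, so each summary profiles one exit.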
492 | setattr(model.module, f"stop{i+1}", "anY_RanDOM_ThiNg")
493 | summary, stats[i+1] = get_model_summary(model, image, verbose=False)
494 | delattr(model.module, f"stop{i+1}")
495 |
496 | saved_stats['params'] = [stats[i+1]['params'] for i in range(4)]
497 | saved_stats['flops'] = [stats[i+1]['flops'] for i in range(4)]
498 | saved_stats['counts'] = [stats[i+1]['counts'] for i in range(4)]
499 | saved_stats['Gflops'] = [f/(1024**3) for f in saved_stats['flops']]
500 | saved_stats['Mparams'] = [f/(10**6) for f in saved_stats['params']]
501 | gflops.append(saved_stats['Gflops'])
502 |
503 | final_stats = saved_stats
504 | final_stats['Gflops'] = []
505 | for i in range(4):
506 | final_stats['Gflops'].append(np.mean([x[i] for x in gflops]))
507 | final_stats['Gflops_mean'] = np.mean(final_stats['Gflops'])
508 | return final_stats
509 |
510 | def testval_ee_profiling_actual(config, test_dataset, testloader, model,
511 | sv_dir='', sv_pred=False):
512 | model.eval()
513 | torch.manual_seed(get_rank())
514 | num_exits = len(config.MODEL.EXTRA.EE_WEIGHTS)
515 | total_time = 0
516 |
517 | stats = {}
518 | stats['time'] = {}
519 | times = []
520 |
521 | with torch.no_grad():
522 | for index, batch in enumerate(tqdm(testloader)):
523 | image, label, _, name = batch
524 | t = []
525 | for i in range(4):
526 | if isinstance(model, nn.DataParallel):
527 | setattr(model.module, f"stop{i+1}", "anY_RanDOM_ThiNg")
528 | else:
529 | setattr(model, f"stop{i+1}", "anY_RanDOM_ThiNg")
530 |
531 | torch.cuda.synchronize()
532 | start = time.time()
533 | out = model(image)
534 | torch.cuda.synchronize()
535 | t.append(time.time() - start)
536 |
537 | if isinstance(model, nn.DataParallel):
538 | delattr(model.module, f"stop{i+1}")
539 | else:
540 | delattr(model, f"stop{i+1}")
541 |
542 | if index > 5:
543 | times.append(t)
544 | if index > 20:
545 | break
546 |
547 | print(t)
548 | for i in range(4):
549 | stats['time'][i] = np.mean([t[i] for t in times])
550 | print(stats)
551 | return stats
552 |
--------------------------------------------------------------------------------
/lib/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from .cityscapes import Cityscapes as cityscapes
--------------------------------------------------------------------------------
/lib/datasets/base_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import cv2
4 | import numpy as np
5 | import random
6 |
7 | import torch
8 | from torch.nn import functional as F
9 | from torch.utils import data
10 |
11 | class BaseDataset(data.Dataset):
12 | def __init__(self,
13 | ignore_label=-1,
14 | base_size=2048,
15 | crop_size=(512, 1024),
16 | downsample_rate=1,
17 | scale_factor=16,
18 | mean=[0.485, 0.456, 0.406],
19 | std=[0.229, 0.224, 0.225]):
20 |
21 | self.base_size = base_size
22 | self.crop_size = crop_size
23 | self.ignore_label = ignore_label
24 |
25 | self.mean = mean
26 | self.std = std
27 | self.scale_factor = scale_factor
28 | self.downsample_rate = 1./downsample_rate
29 |
30 | self.files = []
31 |
32 | def __len__(self):
33 | return len(self.files)
34 |
35 | def input_transform(self, image):
36 | image = image.astype(np.float32)[:, :, ::-1]
37 | image = image / 255.0
38 | image -= self.mean
39 | image /= self.std
40 | return image
41 |
42 | def label_transform(self, label):
43 | return np.array(label).astype('int32')
44 |
45 | def pad_image(self, image, h, w, size, padvalue):
46 | pad_image = image.copy()
47 | pad_h = max(size[0] - h, 0)
48 | pad_w = max(size[1] - w, 0)
49 | if pad_h > 0 or pad_w > 0:
50 | pad_image = cv2.copyMakeBorder(image, 0, pad_h, 0,
51 | pad_w, cv2.BORDER_CONSTANT,
52 | value=padvalue)
53 |
54 | return pad_image
55 |
56 | def rand_crop(self, image, label):
57 | h, w = image.shape[:-1]
58 | image = self.pad_image(image, h, w, self.crop_size,
59 | (0.0, 0.0, 0.0))
60 | label = self.pad_image(label, h, w, self.crop_size,
61 | (self.ignore_label,))
62 |
63 | new_h, new_w = label.shape
64 | x = random.randint(0, new_w - self.crop_size[1])
65 | y = random.randint(0, new_h - self.crop_size[0])
66 | image = image[y:y+self.crop_size[0], x:x+self.crop_size[1]]
67 | label = label[y:y+self.crop_size[0], x:x+self.crop_size[1]]
68 |
69 | return image, label
70 |
71 | def center_crop(self, image, label):
72 | h, w = image.shape[:2]
73 | x = int(round((w - self.crop_size[1]) / 2.))
74 | y = int(round((h - self.crop_size[0]) / 2.))
75 | image = image[y:y+self.crop_size[0], x:x+self.crop_size[1]]
76 | label = label[y:y+self.crop_size[0], x:x+self.crop_size[1]]
77 |
78 | return image, label
79 |
80 | def image_resize(self, image, long_size, label=None):
81 | h, w = image.shape[:2]
82 | if h > w:
83 | new_h = long_size
84 | new_w = np.int(w * long_size / h + 0.5)
85 | else:
86 | new_w = long_size
87 | new_h = np.int(h * long_size / w + 0.5)
88 |
89 | image = cv2.resize(image, (new_w, new_h),
90 | interpolation = cv2.INTER_LINEAR)
91 | if label is not None:
92 | label = cv2.resize(label, (new_w, new_h),
93 | interpolation = cv2.INTER_NEAREST)
94 | else:
95 | return image
96 |
97 | return image, label
98 |
99 | def multi_scale_aug(self, image, label=None,
100 | rand_scale=1, rand_crop=True):
101 | long_size = np.int(self.base_size * rand_scale + 0.5)
102 | if label is not None:
103 | image, label = self.image_resize(image, long_size, label)
104 | if rand_crop:
105 | image, label = self.rand_crop(image, label)
106 | return image, label
107 | else:
108 | image = self.image_resize(image, long_size)
109 | return image
110 |
111 | def gen_sample(self, image, label,
112 | multi_scale=True, is_flip=True, center_crop_test=False):
113 | if multi_scale:
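            # random scale in {0.5, 0.6, ..., 0.5 + scale_factor/10};
            # the configured scale_factor=16 allows up to 2.1x.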
114 | rand_scale = 0.5 + random.randint(0, self.scale_factor) / 10.0
115 | image, label = self.multi_scale_aug(image, label,
116 | rand_scale=rand_scale)
117 |
118 | if center_crop_test:
119 | image, label = self.image_resize(image,
120 | self.base_size,
121 | label)
122 | image, label = self.center_crop(image, label)
123 |
124 | image = self.input_transform(image)
125 | label = self.label_transform(label)
126 |
127 | image = image.transpose((2, 0, 1))
128 |
129 | if is_flip:
130 | flip = np.random.choice(2) * 2 - 1
131 | image = image[:, :, ::flip]
132 | label = label[:, ::flip]
133 |
134 | if self.downsample_rate != 1:
135 | label = cv2.resize(label,
136 | None,
137 | fx=self.downsample_rate,
138 | fy=self.downsample_rate,
139 | interpolation=cv2.INTER_NEAREST)
140 |
141 | return image, label
142 |
143 | def inference(self, model, image, flip=False):
144 | size = image.size()
145 | pred = model(image)
146 | pred = F.upsample(input=pred,
147 | size=(size[-2], size[-1]),
148 | mode='bilinear')
149 | if flip:
150 | flip_img = image.numpy()[:,:,:,::-1]
151 | flip_output = model(torch.from_numpy(flip_img.copy()))
152 | flip_output = F.upsample(input=flip_output,
153 | size=(size[-2], size[-1]),
154 | mode='bilinear')
155 | flip_pred = flip_output.cpu().numpy().copy()
156 | flip_pred = torch.from_numpy(flip_pred[:,:,:,::-1].copy()).cuda()
157 | pred += flip_pred
158 | pred = pred * 0.5
159 | return pred.exp()
160 |
161 | def multi_scale_inference(self, model, image, scales=[1], flip=False):
162 | batch, _, ori_height, ori_width = image.size()
163 | assert batch == 1, "only supporting batchsize 1."
164 | device = torch.device("cuda:%d" % model.device_ids[0])
165 | image = image.numpy()[0].transpose((1,2,0)).copy()
166 | stride_h = np.int(self.crop_size[0] * 2.0 / 3.0)
167 | stride_w = np.int(self.crop_size[1] * 2.0 / 3.0)
168 | final_pred = torch.zeros([1, self.num_classes,
169 | ori_height,ori_width]).to(device)
170 | padvalue = -1.0 * np.array(self.mean) / np.array(self.std)
171 | for scale in scales:
172 | new_img = self.multi_scale_aug(image=image,
173 | rand_scale=scale,
174 | rand_crop=False)
175 | height, width = new_img.shape[:-1]
176 |
177 | if max(height, width) <= np.min(self.crop_size):
178 | new_img = self.pad_image(new_img, height, width,
179 | self.crop_size, padvalue)
180 | new_img = new_img.transpose((2, 0, 1))
181 | new_img = np.expand_dims(new_img, axis=0)
182 | new_img = torch.from_numpy(new_img)
183 | preds = self.inference(model, new_img, flip)
184 | preds = preds[:, :, 0:height, 0:width]
185 | else:
186 | if height < self.crop_size[0] or width < self.crop_size[1]:
187 | new_img = self.pad_image(new_img, height, width,
188 | self.crop_size, padvalue)
189 | new_h, new_w = new_img.shape[:-1]
190 | rows = np.int(np.ceil(1.0 * (new_h -
191 | self.crop_size[0]) / stride_h)) + 1
192 | cols = np.int(np.ceil(1.0 * (new_w -
193 | self.crop_size[1]) / stride_w)) + 1
194 | preds = torch.zeros([1, self.num_classes,
195 | new_h,new_w]).to(device)
196 | count = torch.zeros([1,1, new_h, new_w]).to(device)
197 |
198 | for r in range(rows):
199 | for c in range(cols):
200 | h0 = r * stride_h
201 | w0 = c * stride_w
202 | h1 = min(h0 + self.crop_size[0], new_h)
203 | w1 = min(w0 + self.crop_size[1], new_w)
204 | crop_img = new_img[h0:h1, w0:w1, :]
205 | if h1 == new_h or w1 == new_w:
206 | crop_img = self.pad_image(crop_img,
207 | h1-h0,
208 | w1-w0,
209 | self.crop_size,
210 | padvalue)
211 | crop_img = crop_img.transpose((2, 0, 1))
212 | crop_img = np.expand_dims(crop_img, axis=0)
213 | crop_img = torch.from_numpy(crop_img)
214 | pred = self.inference(model, crop_img, flip)
215 |
216 | preds[:,:,h0:h1,w0:w1] += pred[:,:, 0:h1-h0, 0:w1-w0]
217 | count[:,:,h0:h1,w0:w1] += 1
218 | preds = preds / count
219 | preds = preds[:,:,:height,:width]
220 | preds = F.upsample(preds, (ori_height, ori_width),
221 | mode='bilinear')
222 | final_pred += preds
223 | return final_pred
224 |
--------------------------------------------------------------------------------
/lib/datasets/cityscapes.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import cv2
4 | import numpy as np
5 | from PIL import Image
6 |
7 | import torch
8 | from torch.nn import functional as F
9 |
10 | from .base_dataset import BaseDataset
11 |
12 | class Cityscapes(BaseDataset):
13 | def __init__(self,
14 | root,
15 | list_path,
16 | num_samples=None,
17 | num_classes=19,
18 | multi_scale=True,
19 | flip=True,
20 | ignore_label=-1,
21 | base_size=2048,
22 | crop_size=(512, 1024),
23 | center_crop_test=False,
24 | downsample_rate=1,
25 | scale_factor=16,
26 | mean=[0.485, 0.456, 0.406],
27 | std=[0.229, 0.224, 0.225]):
28 |
29 | super(Cityscapes, self).__init__(ignore_label, base_size,
30 | crop_size, downsample_rate, scale_factor, mean, std,)
31 |
32 | self.root = root
33 | self.list_path = list_path
34 | self.num_classes = num_classes
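        # Per-class weights for the 19 train classes, used to balance the cross-entropy loss.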
35 | self.class_weights = torch.FloatTensor([0.8373, 0.918, 0.866, 1.0345,
36 | 1.0166, 0.9969, 0.9754, 1.0489,
37 | 0.8786, 1.0023, 0.9539, 0.9843,
38 | 1.1116, 0.9037, 1.0865, 1.0955,
39 | 1.0865, 1.1529, 1.0507]).cuda()
40 |
41 | self.multi_scale = multi_scale
42 | self.flip = flip
43 | self.center_crop_test = center_crop_test
44 |
45 | self.img_list = [line.strip().split() for line in open(root+list_path)]
46 |
47 | self.files = self.read_files()
48 | if num_samples:
49 | self.files = self.files[:num_samples]
50 |
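        # Map raw Cityscapes label IDs (0-33) to the 19 train IDs;
        # all other IDs become ignore_label.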
51 | self.label_mapping = {-1: ignore_label, 0: ignore_label,
52 | 1: ignore_label, 2: ignore_label,
53 | 3: ignore_label, 4: ignore_label,
54 | 5: ignore_label, 6: ignore_label,
55 | 7: 0, 8: 1, 9: ignore_label,
56 | 10: ignore_label, 11: 2, 12: 3,
57 | 13: 4, 14: ignore_label, 15: ignore_label,
58 | 16: ignore_label, 17: 5, 18: ignore_label,
59 | 19: 6, 20: 7, 21: 8, 22: 9, 23: 10, 24: 11,
60 | 25: 12, 26: 13, 27: 14, 28: 15,
61 | 29: ignore_label, 30: ignore_label,
62 | 31: 16, 32: 17, 33: 18}
63 |
64 | def read_files(self):
65 | files = []
66 | if 'test' in self.list_path:
67 | for item in self.img_list:
68 | image_path = item
69 | name = os.path.splitext(os.path.basename(image_path[0]))[0]
70 | files.append({
71 | "img": image_path[0],
72 | "name": name,
73 | })
74 | else:
75 | for item in self.img_list:
76 | image_path, label_path = item
77 | name = os.path.splitext(os.path.basename(label_path))[0]
78 | files.append({
79 | "img": image_path,
80 | "label": label_path,
81 | "name": name,
82 | "weight": 1
83 | })
84 | return files
85 |
86 | def convert_label(self, label, inverse=False):
87 | temp = label.copy()
88 | if inverse:
89 | for v, k in self.label_mapping.items():
90 | label[temp == k] = v
91 | else:
92 | for k, v in self.label_mapping.items():
93 | label[temp == k] = v
94 | return label
95 |
96 | def __getitem__(self, index):
97 | item = self.files[index]
98 | name = item["name"]
99 | image = cv2.imread(os.path.join(self.root,'cityscapes',item["img"]),
100 | cv2.IMREAD_COLOR)
101 | size = image.shape
102 |
103 | if 'test' in self.list_path:
104 | image = self.input_transform(image)
105 | image = image.transpose((2, 0, 1))
106 |
107 | return image.copy(), np.array(size), name
108 |
109 | label = cv2.imread(os.path.join(self.root,'cityscapes',item["label"]),
110 | cv2.IMREAD_GRAYSCALE)
111 | label = self.convert_label(label)
112 |
113 | image, label = self.gen_sample(image, label,
114 | self.multi_scale, self.flip,
115 | self.center_crop_test)
116 |
117 | return image.copy(), label.copy(), np.array(size), name
118 |
119 | def multi_scale_inference(self, model, image, scales=[1], flip=False):
120 | batch, _, ori_height, ori_width = image.size()
121 | assert batch == 1, "only supporting batchsize 1."
122 | image = image.numpy()[0].transpose((1,2,0)).copy()
123 | stride_h = np.int(self.crop_size[0] * 1.0)
124 | stride_w = np.int(self.crop_size[1] * 1.0)
125 | final_pred = torch.zeros([1, self.num_classes,
126 | ori_height,ori_width]).cuda()
127 | for scale in scales:
128 | new_img = self.multi_scale_aug(image=image,
129 | rand_scale=scale,
130 | rand_crop=False)
131 | height, width = new_img.shape[:-1]
132 |
133 | if scale <= 1.0:
134 | new_img = new_img.transpose((2, 0, 1))
135 | new_img = np.expand_dims(new_img, axis=0)
136 | new_img = torch.from_numpy(new_img)
137 | preds = self.inference(model, new_img, flip)
138 | preds = preds[:, :, 0:height, 0:width]
139 | else:
140 | new_h, new_w = new_img.shape[:-1]
141 | rows = np.int(np.ceil(1.0 * (new_h -
142 | self.crop_size[0]) / stride_h)) + 1
143 | cols = np.int(np.ceil(1.0 * (new_w -
144 | self.crop_size[1]) / stride_w)) + 1
145 | preds = torch.zeros([1, self.num_classes,
146 | new_h,new_w]).cuda()
147 | count = torch.zeros([1,1, new_h, new_w]).cuda()
148 |
149 | for r in range(rows):
150 | for c in range(cols):
151 | h0 = r * stride_h
152 | w0 = c * stride_w
153 | h1 = min(h0 + self.crop_size[0], new_h)
154 | w1 = min(w0 + self.crop_size[1], new_w)
155 | h0 = max(int(h1 - self.crop_size[0]), 0)
156 | w0 = max(int(w1 - self.crop_size[1]), 0)
157 | crop_img = new_img[h0:h1, w0:w1, :]
158 | crop_img = crop_img.transpose((2, 0, 1))
159 | crop_img = np.expand_dims(crop_img, axis=0)
160 | crop_img = torch.from_numpy(crop_img)
161 | pred = self.inference(model, crop_img, flip)
162 | preds[:,:,h0:h1,w0:w1] += pred[:,:, 0:h1-h0, 0:w1-w0]
163 | count[:,:,h0:h1,w0:w1] += 1
164 | preds = preds / count
165 | preds = preds[:,:,:height,:width]
166 | preds = F.upsample(preds, (ori_height, ori_width),
167 | mode='bilinear')
168 | final_pred += preds
169 | return final_pred
170 |
171 | def get_palette(self, n):
172 | palette = [0] * (n * 3)
173 | for j in range(0, n):
174 | lab = j
175 | palette[j * 3 + 0] = 0
176 | palette[j * 3 + 1] = 0
177 | palette[j * 3 + 2] = 0
178 | i = 0
179 | while lab:
180 | palette[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i))
181 | palette[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i))
182 | palette[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i))
183 | i += 1
184 | lab >>= 3
185 | return palette
186 |
187 | def get_palette_cityscapes(self, n):
188 | palette = [0] * (n * 3)
189 | from cityscapesscripts.helpers.labels import labels
190 | trainId2color = {label.trainId: label.color for label in labels if (label.trainId != 255 and label.trainId != -1)}
191 | for trainId, color in trainId2color.items():
192 | palette[trainId*3] = color[0]
193 | palette[trainId*3 + 1] = color[1]
194 | palette[trainId*3 + 2] = color[2]
195 |
196 | return palette
197 |
198 |
199 | def save_pred(self, preds, sv_path, name):
200 |
201 | palette = self.get_palette_cityscapes(256)
202 |
203 | preds = preds.cpu().numpy().copy()
204 | preds = np.asarray(np.argmax(preds, axis=1), dtype=np.uint8)
205 | for i in range(preds.shape[0]):
206 | pred = preds[i]
207 | save_img = Image.fromarray(pred)
208 | save_img.putpalette(palette)
209 | save_img.save(os.path.join(sv_path, name[i]+'.png'))
210 |
211 |
212 | def save_ts(self, t, sv_path, name):
213 | palette = self.get_palette(256)
214 | t = t.cpu().numpy().copy()
215 | for i in range(t.shape[0]):
216 | pred = self.convert_label(t[i], inverse=True)
217 | save_img = Image.fromarray(pred)
218 | save_img.putpalette(palette)
219 | save_img.save(os.path.join(sv_path, name[i]+'.png'))
220 |
221 |
222 |
223 |
--------------------------------------------------------------------------------
/lib/models/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import models.model_anytime
--------------------------------------------------------------------------------
/lib/models/conv_mask.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import pdb, time
4 | from tqdm import tqdm
5 | import numpy as np
6 | import torch.nn.functional as F
7 |
8 | class conv_mask_uniform(nn.Conv2d):
9 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True, p=0.5, interpolate='none'):
10 | super().__init__(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
11 | self.mask = None
12 | self.mask_built = False
13 | self.p = p
14 |
15 | self.interpolate = interpolate
16 | self.r = 7
17 | self.padding_interpolate = 3
18 |
19 | self.Lambda = nn.Parameter(torch.tensor(3.0))
20 | square_dis = np.zeros((self.r, self.r))
21 | center_point = (square_dis.shape[0]//2, square_dis.shape[1]//2)
22 |
23 | for i in range(square_dis.shape[0]):
24 | for j in range(square_dis.shape[1]):
25 | square_dis[i][j] = (i - center_point[0])**2 + (j - center_point[1])**2
26 |
27 | square_dis[center_point[0]][center_point[1]] = 100000.0
28 |
29 | self.square_dis = nn.Parameter(torch.Tensor(square_dis), requires_grad=False)
30 |
31 | def build_mask(self, x):
32 | mask_p = x.new(x.shape[2:]).fill_(self.p)
33 | mask = torch.bernoulli(mask_p)
34 | self.mask = mask[None, None, :, :].float()
35 | self.mask_built = True
36 |
37 | if self.in_channels == 3:
38 | print('Mask sum:', torch.sum(self.mask))
39 |
40 | def build_mask_random(self, x):
41 | mask_p = x.new(size=(x.shape[0], *x.shape[2:])).fill_(self.p)
42 | mask = torch.bernoulli(mask_p)
43 | self.mask = mask[:, None, :, :].float()
44 | self.mask_built = True
45 |
46 | def set_mask(self, mask):
47 | self.mask = mask[:, None, :, :]
48 | self.mask_built = True
49 |
50 | def forward(self, x):
51 | y = super().forward(x)
52 | self.out_h, self.out_w = y.size(-2), y.size(-1)
53 | if not self.mask_built:
54 | self.build_mask_random(y)
55 |
56 | kernel = (-(self.Lambda**2) * self.square_dis.detach()).exp()
57 | kernel = kernel / (kernel.sum() + 10**(-5))
58 | kernel = kernel.expand((self.out_channels, 1, kernel.size(0), kernel.size(1)))
59 | interpolated = F.conv2d(y * self.mask, kernel, stride=1, padding=self.padding_interpolate, groups=self.out_channels)
60 |
61 | out = y * self.mask + interpolated * (1 - self.mask)
62 | self.mask_built = False
63 |
64 | return out
65 |
66 | if __name__ == '__main__':
67 |     a = conv_mask_uniform(in_channels=10, out_channels=10, kernel_size=3, padding=1)
68 |
69 |
--------------------------------------------------------------------------------
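`conv_mask_uniform` above runs the parent convolution everywhere, keeps the result only at mask locations, and fills the remaining pixels by depthwise-convolving the masked output with a learned distance-decay kernel `exp(-Lambda^2 * d^2)` (the center entry of `square_dis` is set to 100000.0 so a pixel never interpolates from itself). A self-contained sketch of that interpolation step, with illustrative shapes and a fixed `Lambda`:

```
# Sketch of the masked-convolution interpolation; values are illustrative.
import torch
import torch.nn.functional as F

y = torch.randn(1, 4, 16, 16)                             # conv output (N, C, H, W)
mask = torch.bernoulli(torch.full((1, 1, 16, 16), 0.5))   # keep ~50% of pixels

r, lam = 7, 3.0
c = r // 2
ii = torch.arange(r, dtype=torch.float).view(-1, 1).expand(r, r)
jj = torch.arange(r, dtype=torch.float).view(1, -1).expand(r, r)
d2 = (ii - c) ** 2 + (jj - c) ** 2
d2[c, c] = 1e5                                            # exclude the center pixel
kernel = (-(lam ** 2) * d2).exp()
kernel = kernel / (kernel.sum() + 1e-5)
kernel = kernel.expand(4, 1, r, r)                        # depthwise: one kernel per channel

interp = F.conv2d(y * mask, kernel, padding=c, groups=4)
out = y * mask + interp * (1 - mask)                      # masked-out pixels get interpolated values
```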
/lib/models/model_anytime.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import os
6 | import logging
7 | import functools
8 |
9 | import numpy as np
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch._utils
14 | import torch.nn.functional as F
15 |
16 | from utils.utils import AverageMeter
17 | import pdb, time
18 | from .conv_mask import conv_mask_uniform
19 | from functools import partial
20 |
21 | from utils.utils import get_rank
22 |
23 | BatchNorm2d = nn.BatchNorm2d
24 | BN_MOMENTUM = 0.01
25 | logger = logging.getLogger(__name__)
26 |
27 | def conv3x3(in_planes, out_planes, stride=1):
28 | return used_conv(in_planes, out_planes, kernel_size=3, stride=stride,
29 | padding=1, bias=False)
30 |
31 |
32 | class BasicBlock(nn.Module):
33 | expansion = 1
34 |
35 | def __init__(self, inplanes, planes, stride=1, downsample=None):
36 | super(BasicBlock, self).__init__()
37 | self.conv1 = conv3x3(inplanes, planes, stride)
38 | self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
39 | self.relu = nn.ReLU(inplace=True)
40 | self.conv2 = conv3x3(planes, planes)
41 | self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
42 | self.downsample = downsample
43 | self.stride = stride
44 |
45 | def forward(self, x):
46 | residual = x
47 |
48 | out = self.conv1(x)
49 | out = self.bn1(out)
50 | out = self.relu(out)
51 |
52 | out = self.conv2(out)
53 | out = self.bn2(out)
54 |
55 | if self.downsample is not None:
56 | residual = self.downsample(x)
57 |
58 | out += residual
59 | out = self.relu(out)
60 |
61 | return out
62 |
63 |
64 | class Bottleneck(nn.Module):
65 | expansion = 4
66 |
67 | def __init__(self, inplanes, planes, stride=1, downsample=None):
68 | super(Bottleneck, self).__init__()
69 | self.conv1 = used_conv(inplanes, planes, kernel_size=1, bias=False)
70 | self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
71 | self.conv2 = used_conv(planes, planes, kernel_size=3, stride=stride,
72 | padding=1, bias=False)
73 | self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
74 | self.conv3 = used_conv(planes, planes * self.expansion, kernel_size=1,
75 | bias=False)
76 | self.bn3 = BatchNorm2d(planes * self.expansion,
77 | momentum=BN_MOMENTUM)
78 | self.relu = nn.ReLU(inplace=True)
79 | self.downsample = downsample
80 | self.stride = stride
81 |
82 | def forward(self, x):
83 | residual = x
84 |
85 | out = self.conv1(x)
86 | out = self.bn1(out)
87 | out = self.relu(out)
88 |
89 | out = self.conv2(out)
90 | out = self.bn2(out)
91 | out = self.relu(out)
92 |
93 | out = self.conv3(out)
94 | out = self.bn3(out)
95 |
96 | if self.downsample is not None:
97 | residual = self.downsample(x)
98 |
99 | out += residual
100 | out = self.relu(out)
101 |
102 | return out
103 |
104 |
105 | class HighResolutionModule(nn.Module):
106 | def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
107 | num_channels, fuse_method, multi_scale_output=True):
108 | super(HighResolutionModule, self).__init__()
109 | self._check_branches(
110 | num_branches, blocks, num_blocks, num_inchannels, num_channels)
111 |
112 | self.num_inchannels = num_inchannels
113 | self.fuse_method = fuse_method
114 | self.num_branches = num_branches
115 |
116 | self.multi_scale_output = multi_scale_output
117 |
118 | self.branches = self._make_branches(
119 | num_branches, blocks, num_blocks, num_channels)
120 | self.fuse_layers = self._make_fuse_layers()
121 | self.relu = nn.ReLU(inplace=True)
122 |
123 | def _check_branches(self, num_branches, blocks, num_blocks,
124 | num_inchannels, num_channels):
125 | if num_branches != len(num_blocks):
126 | error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
127 | num_branches, len(num_blocks))
128 | logger.error(error_msg)
129 | raise ValueError(error_msg)
130 |
131 | if num_branches != len(num_channels):
132 | error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
133 | num_branches, len(num_channels))
134 | logger.error(error_msg)
135 | raise ValueError(error_msg)
136 |
137 | if num_branches != len(num_inchannels):
138 | error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
139 | num_branches, len(num_inchannels))
140 | logger.error(error_msg)
141 | raise ValueError(error_msg)
142 |
143 | def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
144 | stride=1):
145 | downsample = None
146 | if stride != 1 or \
147 | self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
148 | downsample = nn.Sequential(
149 | used_conv(self.num_inchannels[branch_index],
150 | num_channels[branch_index] * block.expansion,
151 | kernel_size=1, stride=stride, bias=False),
152 | BatchNorm2d(num_channels[branch_index] * block.expansion,
153 | momentum=BN_MOMENTUM),
154 | )
155 |
156 | layers = []
157 | layers.append(block(self.num_inchannels[branch_index],
158 | num_channels[branch_index], stride, downsample))
159 | self.num_inchannels[branch_index] = \
160 | num_channels[branch_index] * block.expansion
161 | for i in range(1, num_blocks[branch_index]):
162 | layers.append(block(self.num_inchannels[branch_index],
163 | num_channels[branch_index]))
164 |
165 | return nn.Sequential(*layers)
166 |
167 | def _make_branches(self, num_branches, block, num_blocks, num_channels):
168 | branches = []
169 |
170 | for i in range(num_branches):
171 | branches.append(
172 | self._make_one_branch(i, block, num_blocks, num_channels))
173 |
174 | return nn.ModuleList(branches)
175 |
176 | def _make_fuse_layers(self):
177 | if self.num_branches == 1:
178 | return None
179 |
180 | num_branches = self.num_branches
181 | num_inchannels = self.num_inchannels
182 | fuse_layers = []
183 | for i in range(num_branches if self.multi_scale_output else 1):
184 | fuse_layer = []
185 | for j in range(num_branches):
186 | if j > i:
187 | fuse_layer.append(nn.Sequential(
188 | used_conv(num_inchannels[j],
189 | num_inchannels[i],
190 | 1,
191 | 1,
192 | 0,
193 | bias=False),
194 | BatchNorm2d(num_inchannels[i], momentum=BN_MOMENTUM)))
195 | elif j == i:
196 | fuse_layer.append(None)
197 | else:
198 | conv3x3s = []
199 | for k in range(i-j):
200 | if k == i - j - 1:
201 | num_outchannels_conv3x3 = num_inchannels[i]
202 | conv3x3s.append(nn.Sequential(
203 | used_conv(num_inchannels[j],
204 | num_outchannels_conv3x3,
205 | 3, 2, 1, bias=False),
206 | BatchNorm2d(num_outchannels_conv3x3,
207 | momentum=BN_MOMENTUM)))
208 | else:
209 | num_outchannels_conv3x3 = num_inchannels[j]
210 | conv3x3s.append(nn.Sequential(
211 | used_conv(num_inchannels[j],
212 | num_outchannels_conv3x3,
213 | 3, 2, 1, bias=False),
214 | BatchNorm2d(num_outchannels_conv3x3,
215 | momentum=BN_MOMENTUM),
216 | nn.ReLU(inplace=True)))
217 | fuse_layer.append(nn.Sequential(*conv3x3s))
218 | fuse_layers.append(nn.ModuleList(fuse_layer))
219 |
220 | return nn.ModuleList(fuse_layers)
221 |
222 | def get_num_inchannels(self):
223 | return self.num_inchannels
224 |
225 | def forward(self, x):
226 | if self.num_branches == 1:
227 | return [self.branches[0](x[0])]
228 |
229 | for i in range(self.num_branches):
230 | x[i] = self.branches[i](x[i])
231 |
232 | x_fuse = []
233 | for i in range(len(self.fuse_layers)):
234 | y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
235 | for j in range(1, self.num_branches):
236 | if i == j:
237 | y = y + x[j]
238 | elif j > i:
239 | width_output = x[i].shape[-1]
240 | height_output = x[i].shape[-2]
241 | y = y + F.interpolate(
242 | self.fuse_layers[i][j](x[j]),
243 | size=[height_output, width_output],
244 | mode='bilinear')
245 | else:
246 | y = y + self.fuse_layers[i][j](x[j])
247 | x_fuse.append(self.relu(y))
248 |
249 | return x_fuse
250 |
251 |
252 | blocks_dict = {
253 | 'BASIC': BasicBlock,
254 | 'BOTTLENECK': Bottleneck
255 | }
256 |
257 |
258 |
259 | class HighResolutionNet(nn.Module):
260 | def _make_transition_layer(
261 | self, num_channels_pre_layer, num_channels_cur_layer):
262 | num_branches_cur = len(num_channels_cur_layer)
263 | num_branches_pre = len(num_channels_pre_layer)
264 |
265 | transition_layers = []
266 | for i in range(num_branches_cur):
267 | if i < num_branches_pre:
268 | if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
269 | transition_layers.append(nn.Sequential(
270 | used_conv(num_channels_pre_layer[i],
271 | num_channels_cur_layer[i],
272 | 3,
273 | 1,
274 | 1,
275 | bias=False),
276 | BatchNorm2d(
277 | num_channels_cur_layer[i], momentum=BN_MOMENTUM),
278 | nn.ReLU(inplace=True)))
279 | else:
280 | transition_layers.append(None)
281 | else:
282 | conv3x3s = []
283 | for j in range(i+1-num_branches_pre):
284 | inchannels = num_channels_pre_layer[-1]
285 | outchannels = num_channels_cur_layer[i] \
286 | if j == i-num_branches_pre else inchannels
287 | conv3x3s.append(nn.Sequential(
288 | used_conv(
289 | inchannels, outchannels, 3, 2, 1, bias=False),
290 | BatchNorm2d(outchannels, momentum=BN_MOMENTUM),
291 | nn.ReLU(inplace=True)))
292 | transition_layers.append(nn.Sequential(*conv3x3s))
293 |
294 | return nn.ModuleList(transition_layers)
295 |
296 | def _make_layer(self, block, inplanes, planes, blocks, stride=1):
297 | downsample = None
298 | if stride != 1 or inplanes != planes * block.expansion:
299 | downsample = nn.Sequential(
300 | used_conv(inplanes, planes * block.expansion,
301 | kernel_size=1, stride=stride, bias=False),
302 | BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
303 | )
304 |
305 | layers = []
306 | layers.append(block(inplanes, planes, stride, downsample))
307 | inplanes = planes * block.expansion
308 | for i in range(1, blocks):
309 | layers.append(block(inplanes, planes))
310 |
311 | return nn.Sequential(*layers)
312 |
313 | def _make_stage(self, layer_config, num_inchannels,
314 | multi_scale_output=True):
315 | num_modules = layer_config['NUM_MODULES']
316 | num_branches = layer_config['NUM_BRANCHES']
317 | num_blocks = layer_config['NUM_BLOCKS']
318 | num_channels = layer_config['NUM_CHANNELS']
319 | block = blocks_dict[layer_config['BLOCK']]
320 | fuse_method = layer_config['FUSE_METHOD']
321 |
322 | modules = []
323 | for i in range(num_modules):
324 | if not multi_scale_output and i == num_modules - 1:
325 | reset_multi_scale_output = False
326 | else:
327 | reset_multi_scale_output = True
328 | modules.append(
329 | HighResolutionModule(num_branches,
330 | block,
331 | num_blocks,
332 | num_inchannels,
333 | num_channels,
334 | fuse_method,
335 | reset_multi_scale_output)
336 | )
337 | num_inchannels = modules[-1].get_num_inchannels()
338 |
339 | return nn.Sequential(*modules), num_inchannels
340 |
341 | def __init__(self, config, **kwargs):
342 |
343 |
344 | super(HighResolutionNet, self).__init__()
345 | extra = config.MODEL.EXTRA
346 | self.extra = extra
347 | self.mask_cfg = config.MASK
348 |
349 | global mask_conv, mask_conv_no_interpolate
350 | mask_conv = partial(conv_mask_uniform, p=self.mask_cfg.P, interpolate=self.mask_cfg.INTERPOLATION)
351 | mask_conv_no_interpolate = partial(conv_mask_uniform, p=self.mask_cfg.P, interpolate='none')
352 | global used_conv
353 | used_conv = nn.Conv2d
354 |
355 | self.num_exits = len(extra.EE_WEIGHTS)
356 | self.num_classes = config.DATASET.NUM_CLASSES
357 | if 'profiling_cpu' in kwargs or 'profiling_gpu' in kwargs:
358 | self.profiling_meters = [AverageMeter() for i in range(self.num_exits)]
359 | self.profiling_gpu = 'profiling_gpu' in kwargs
360 | self.profiling_cpu = 'profiling_cpu' in kwargs
361 | self.forward_count = 0
362 | else:
363 | self.profiling_gpu, self.profiling_cpu = False, False
364 |
365 | self.conv1 = used_conv(3, 64, kernel_size=3, stride=2, padding=1,
366 | bias=False)
367 | self.bn1 = BatchNorm2d(64, momentum=BN_MOMENTUM)
368 | self.conv2 = used_conv(64, 64, kernel_size=3, stride=2, padding=1,
369 | bias=False)
370 | self.bn2 = BatchNorm2d(64, momentum=BN_MOMENTUM)
371 | self.relu = nn.ReLU(inplace=True)
372 |
373 | self.stage1_cfg = extra['STAGE1']
374 | num_channels = self.stage1_cfg['NUM_CHANNELS'][0]
375 | block = blocks_dict[self.stage1_cfg['BLOCK']]
376 | num_blocks = self.stage1_cfg['NUM_BLOCKS'][0]
377 | self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)
378 | stage1_out_channel = block.expansion*num_channels
379 | self.exit1 = self.get_exit_layer(stage1_out_channel, config, exit_number=1)
380 |
381 | if self.mask_cfg.USE:
382 | used_conv = mask_conv
383 | else:
384 | used_conv = nn.Conv2d
385 |
386 | self.stage2_cfg = extra['STAGE2']
387 | num_channels = self.stage2_cfg['NUM_CHANNELS']
388 | block = blocks_dict[self.stage2_cfg['BLOCK']]
389 | num_channels = [
390 | num_channels[i] * block.expansion for i in range(len(num_channels))]
391 | self.transition1 = self._make_transition_layer(
392 | [stage1_out_channel], num_channels)
393 | self.stage2, pre_stage_channels = self._make_stage(
394 | self.stage2_cfg, num_channels)
395 |         self.exit2 = self.get_exit_layer(int(np.sum(pre_stage_channels)), config, exit_number=2)
396 |
397 | if self.mask_cfg.USE:
398 | used_conv = mask_conv
399 | else:
400 | used_conv = nn.Conv2d
401 |
402 |
403 | self.stage3_cfg = extra['STAGE3']
404 | num_channels = self.stage3_cfg['NUM_CHANNELS']
405 | block = blocks_dict[self.stage3_cfg['BLOCK']]
406 | num_channels = [
407 | num_channels[i] * block.expansion for i in range(len(num_channels))]
408 | self.transition2 = self._make_transition_layer(
409 | pre_stage_channels, num_channels)
410 | self.stage3, pre_stage_channels = self._make_stage(
411 | self.stage3_cfg, num_channels)
412 |         self.exit3 = self.get_exit_layer(int(np.sum(pre_stage_channels)), config, exit_number=3)
413 |
414 | self.stage4_cfg = extra['STAGE4']
415 | num_channels = self.stage4_cfg['NUM_CHANNELS']
416 | block = blocks_dict[self.stage4_cfg['BLOCK']]
417 | num_channels = [
418 | num_channels[i] * block.expansion for i in range(len(num_channels))]
419 | self.transition3 = self._make_transition_layer(
420 | pre_stage_channels, num_channels)
421 | self.stage4, pre_stage_channels = self._make_stage(
422 | self.stage4_cfg, num_channels, multi_scale_output=True)
423 |
424 |         last_inp_channels = int(np.sum(pre_stage_channels))
425 | self.last_layer = self.get_exit_layer(last_inp_channels, config, last=True)
426 |
427 |         print('trainable parameters:', sum(p.numel() for p in self.parameters() if p.requires_grad))
428 |         print('total parameters:', sum(p.numel() for p in self.parameters()))
429 |
430 |
431 |
432 | def profile(self, out, index):
433 | if not (self.profiling_cpu or self.profiling_gpu):
434 | return
435 | self.forward_count += 1
436 | print(self.forward_count)
437 | start_count = 25 * 4
438 | if self.forward_count < start_count:
439 | return
440 |
441 | if self.profiling_cpu:
442 | self.profiling_meters[index].update(time.time() - self.start)
443 | elif self.profiling_gpu:
444 | tmp_out = out.cpu()
445 | torch.cuda.synchronize()
446 | self.profiling_meters[index].update(time.time() - self.start)
447 | else:
448 | return
449 | if index == self.num_exits - 1 and (self.forward_count > start_count + 10):
450 | times = [self.profiling_meters[i].average() for i in range(self.num_exits)]
451 | times.append(np.mean(times))
452 | print('\t'.join(['{:.3f}'.format(x) for x in times]))
453 |
454 | def get_points_from_confs(self, confs, ratio):
455 | bs, h, w = confs.size(0), confs.size(2), confs.size(3)
456 | idx = torch.arange(h * w, device=confs.device)
457 | h_pos = idx // w
458 | w_pos = idx % w
459 | point_coords_int = torch.cat((h_pos.unsqueeze(1), w_pos.unsqueeze(1)), dim=1)
460 | point_coords_int = point_coords_int.unsqueeze(0).repeat(bs, 1, 1)
461 | num_sampled = point_coords_int.size(1)
462 |
463 | num_certain_points = int(ratio * h * w)
464 | point_certainties = confs.view(bs, 1, -1)
465 | values, idx = torch.topk(point_certainties[:, 0, :], k=num_certain_points, dim=1)
466 | shift = num_sampled * torch.arange(bs, dtype=torch.long, device=confs.device)
467 | idx += shift[:, None]
468 | point_coords_selected_int = point_coords_int.view(-1, 2)[idx.view(-1), :].view(
469 | bs, num_certain_points, 2
470 | )
471 | point_coords_selected_frac = torch.cat(( (point_coords_selected_int[:, :, 0:1] + 0.5)/float(h), (point_coords_selected_int[:, :, 1:2] + 0.5)/float(w)), dim=2)
472 | return point_coords_selected_int, point_coords_selected_frac
473 |
474 | def get_resized_mask_from_logits(self, logits, h, w,criterion):
475 | if criterion == 'conf_thre':
476 | resized_logits = F.interpolate(logits, size=(h, w))
477 | resized_probs = F.softmax(resized_logits, dim=1)
478 | resized_confs, _ = resized_probs.max(dim=1, keepdim=True)
479 | mask = (resized_confs <= self.mask_cfg.CONF_THRE).float().view(logits.size(0), h, w)
480 | elif criterion == 'entropy_thre':
481 | resized_logits = F.interpolate(logits, size=(h, w))
482 | resized_probs = F.softmax(resized_logits, dim=1)
483 |             resized_confs = torch.sum(-resized_probs * torch.log(resized_probs), dim=1, keepdim=True)  # per-pixel entropy
484 | mask = (resized_confs >= self.mask_cfg.ENTROPY_THRE).float().view(logits.size(0), h, w)
485 | return mask
486 |
487 | def generate_grid_priors(self):
488 | if hasattr(self, 'mask_grid_prior_dict') and len(self.mask_grid_prior_dict) > 0:
489 | return
490 | self.mask_grid_prior_dict = {}
491 |
492 | for m in self.modules():
493 | if isinstance(m, conv_mask_uniform):
494 | try:
495 | h,w = m.out_h, m.out_w
496 |                 except AttributeError:
497 |                     logger.info("First forward pass; output sizes not collected yet, skipping grid prior generation")
498 | break
499 |
500 | if (h,w) in self.mask_grid_prior_dict:
501 | continue
502 | logger.info(f"generating grid priors for size {(h,w)}")
503 | res = torch.zeros((h, w), device=m.weight.device)
504 | stride = self.mask_cfg.GRID_STRIDE
505 | start = (stride - 1) // 2
506 |
507 | for i in range(start, res.size(0), stride):
508 | for j in range(start, res.size(1), stride):
509 | res[i][j] = 1.
510 |
511 | self.mask_grid_prior_dict[(h, w)] = res
512 |
513 | def set_masks(self, logits):
514 | self.mask_dict = {}
515 | for m in self.modules():
516 | if isinstance(m, conv_mask_uniform):
517 | try:
518 | h,w = m.out_h, m.out_w
519 |                 except AttributeError:
520 |                     logger.info("First forward pass; output sizes not collected yet, skipping mask setup")
521 | break
522 |
523 | if (h,w) in self.mask_dict:
524 | m.set_mask(self.mask_dict[(h,w)])
525 | else:
526 | self.mask_dict[(h,w)] = self.get_resized_mask_from_logits(logits, h, w, criterion=self.mask_cfg.CRIT)
527 |
528 | if self.mask_cfg.GRID_PRIOR:
529 | self.mask_dict[(h,w)] = torch.max(self.mask_dict[(h,w)], self.mask_grid_prior_dict[(h, w)])
530 | m.set_mask(self.mask_dict[(h,w)])
531 |
532 |
533 | def set_part_masks(self, logits, ref_name, masked_modules):
534 | start = time.time()
535 | self.part_mask_dicts[ref_name] = {}
536 | for module in masked_modules:
537 | for m in module.modules():
538 | if isinstance(m, conv_mask_uniform):
539 | try:
540 | h,w = m.out_h, m.out_w
541 |                     except AttributeError:
542 |                         logger.info("First forward pass; output sizes not collected yet, skipping mask setup")
543 | break
544 | if (h,w) in self.part_mask_dicts[ref_name]:
545 | m.set_mask(self.part_mask_dicts[ref_name][(h,w)])
546 | else:
547 | self.part_mask_dicts[ref_name][(h,w)] = self.get_resized_mask_from_logits(logits, h, w, criterion=self.mask_cfg.CRIT)
548 | m.set_mask(self.part_mask_dicts[ref_name][(h,w)])
549 |
550 | def forward(self, x):
551 | self.part_mask_dicts = {}
552 |
553 | if self.profiling_gpu:
554 | torch.cuda.synchronize()
555 | if self.profiling_cpu or self.profiling_gpu:
556 | self.start = time.time()
557 |
558 | x = self.conv1(x)
559 | x = self.bn1(x)
560 | x = self.relu(x)
561 | x = self.conv2(x)
562 | x = self.bn2(x)
563 | x = self.relu(x)
564 | x = self.layer1(x)
565 | out1_feat = self.get_exit_input([x], detach=self.extra.EARLY_DETACH)
566 | out1 = self.exit1(out1_feat) # logits of exit 1
567 | out_size = (out1.size(-2), out1.size(-1))
568 |
569 | # Set mask for all conv_mask modules between exit 1 and exit 2
570 | if self.mask_cfg.USE:
571 | self.set_part_masks(out1, 'out1', [self.transition1, self.stage2, self.exit2])
572 | if hasattr(self, "stop1"):
573 | return out1
574 |
575 | x_list = []
576 | for i in range(self.stage2_cfg['NUM_BRANCHES']):
577 | if self.transition1[i] is not None:
578 | x_list.append(self.transition1[i](x))
579 | else:
580 | x_list.append(x)
581 | y_list = self.stage2(x_list)
582 | out2_feat = self.get_exit_input(y_list, detach=self.extra.EARLY_DETACH)
583 | out2 = self.exit2(out2_feat)
584 |
585 | if self.mask_cfg.USE:
586 | # Compute logits, aggregate results from the previous exit
587 | if self.mask_cfg.AGGR == 'copy' and len(self.part_mask_dicts['out1']) > 0:
588 | result_mask = self.part_mask_dicts['out1'][out_size][:, None, :, :]
589 | out2 = out1 * (1-result_mask) + out2 * result_mask
590 | # Set mask for all conv_mask modules between exit 2 and exit 3
591 | self.set_part_masks(out2, 'out2', [self.transition2, self.stage3, self.exit3])
592 | if hasattr(self, "stop2"):
593 | return out2
594 |
595 | x_list = []
596 | for i in range(self.stage3_cfg['NUM_BRANCHES']):
597 | if self.transition2[i] is not None:
598 | x_list.append(self.transition2[i](y_list[-1]))
599 | else:
600 | x_list.append(y_list[i])
601 | y_list = self.stage3(x_list)
602 | out3_feat = self.get_exit_input(y_list, detach=self.extra.EARLY_DETACH)
603 | out3 = self.exit3(out3_feat)
604 |
605 | if self.mask_cfg.USE:
606 | # Compute logits, aggregate results from the previous exit
607 | if self.mask_cfg.AGGR == 'copy' and len(self.part_mask_dicts['out2']) > 0:
608 | result_mask = self.part_mask_dicts['out2'][out_size][:, None, :, :]
609 | out3 = out2 * (1-result_mask) + out3 * result_mask
610 | # Set mask for all conv_mask module between exit 3 and exit 4
611 | self.set_part_masks(out3, 'out3', [self.transition3, self.stage4, self.last_layer])
612 | if hasattr(self, "stop3"):
613 | return out3
614 |
615 | x_list = []
616 | for i in range(self.stage4_cfg['NUM_BRANCHES']):
617 | if self.transition3[i] is not None:
618 | x_list.append(self.transition3[i](y_list[-1]))
619 | else:
620 | x_list.append(y_list[i])
621 |
622 | y_list = self.stage4(x_list)
623 | out4_feat = self.get_exit_input(y_list, detach=False)
624 | out4 = self.last_layer(out4_feat)
625 |
626 | if self.mask_cfg.USE:
627 | if self.mask_cfg.AGGR == 'copy' and len(self.part_mask_dicts['out3']) > 0:
628 | result_mask = self.part_mask_dicts['out3'][out_size][:, None, :, :]
629 | out4 = out3 * (1-result_mask) + out4 * result_mask
630 |
631 | self.profile(out4, 3)
632 | if hasattr(self, "stop4"):
633 | return out4
634 |
635 | outs = [out1, out2, out3, out4]
636 |
637 | return outs
638 |
639 |
640 | def get_exit_layer(self, num_channels, config, last=False, exit_number=0):
641 | print(f'EXIT num_channels:{num_channels}')
642 | extra = config.MODEL.EXTRA
643 | layer_type = config.EXIT.TYPE if (not last) else 'original'
644 |
645 | inter_channel = int(num_channels)
646 |
647 | if layer_type == 'flex':
648 | assert exit_number in [1,2,3]
649 | type_map = {1: 'downup_pool_1x1_inter_triple', 2: 'downup_pool_1x1_inter_double', 3: 'downup_pool_1x1_inter'}
650 | layer_type = type_map[exit_number]
651 | inter_channel = config.EXIT.INTER_CHANNEL
652 |
653 | if self.mask_cfg.USE:
654 | exit_conv = used_conv
655 | else:
656 | exit_conv = nn.Conv2d
657 |
658 | norm_layer = BatchNorm2d(num_channels, momentum=BN_MOMENTUM)
659 |
660 | if layer_type == 'original':
661 | exit_layer = [
662 | exit_conv(
663 | in_channels=num_channels,
664 | out_channels=num_channels,
665 | kernel_size=1,
666 | stride=1,
667 | padding=0,
668 | bias=True),
669 |
670 | norm_layer,
671 | nn.ReLU(inplace=True),
672 | exit_conv(
673 | in_channels=num_channels,
674 | out_channels=config.DATASET.NUM_CLASSES,
675 | kernel_size=config.EXIT.FINAL_CONV_KERNEL,
676 | stride=1,
677 | padding=1 if config.EXIT.FINAL_CONV_KERNEL == 3 else 0,
678 | bias=True,
679 | )
680 | ]
681 |
682 | elif layer_type == 'downup_pool_1x1_inter':
683 | exit_layer = [
684 | nn.AvgPool2d(2, 2),
685 | exit_conv(
686 | in_channels=num_channels,
687 | out_channels=inter_channel,
688 | kernel_size=1,
689 | stride=1,
690 | padding=0,
691 | bias=True),
692 | BatchNorm2d(inter_channel, momentum=BN_MOMENTUM),
693 | nn.ReLU(inplace=True),
694 | nn.Upsample(scale_factor=2, mode='bilinear'),
695 | exit_conv(
696 | in_channels=inter_channel,
697 | out_channels=config.DATASET.NUM_CLASSES,
698 | kernel_size=config.EXIT.FINAL_CONV_KERNEL,
699 | stride=1,
700 | padding=1 if config.EXIT.FINAL_CONV_KERNEL == 3 else 0,
701 | bias=True,
702 | )
703 | ]
704 |
705 |
706 |
707 | elif layer_type == 'downup_pool_1x1_inter_double':
708 | exit_layer = [
709 | nn.AvgPool2d(2, 2),
710 | exit_conv(
711 | in_channels=num_channels,
712 | out_channels=inter_channel,
713 | kernel_size=1,
714 | stride=1,
715 | padding=0,
716 | bias=True),
717 | BatchNorm2d(inter_channel, momentum=BN_MOMENTUM),
718 | nn.ReLU(inplace=True),
719 |
720 | nn.AvgPool2d(2, 2),
721 | exit_conv(
722 | in_channels=inter_channel,
723 | out_channels=inter_channel,
724 | kernel_size=1,
725 | stride=1,
726 | padding=0,
727 | bias=True
728 | ),
729 | BatchNorm2d(inter_channel, momentum=BN_MOMENTUM),
730 | nn.ReLU(inplace=True),
731 |
732 | nn.Upsample(scale_factor=2, mode='bilinear'),
733 | exit_conv(
734 | in_channels=inter_channel,
735 | out_channels=inter_channel,
736 | kernel_size=1,
737 | stride=1,
738 | padding=0,
739 | bias=True
740 | ),
741 | BatchNorm2d(inter_channel, momentum=BN_MOMENTUM),
742 | nn.ReLU(inplace=True),
743 |
744 | nn.Upsample(scale_factor=2, mode='bilinear'),
745 | exit_conv(
746 | in_channels=inter_channel,
747 | out_channels=config.DATASET.NUM_CLASSES,
748 | kernel_size=config.EXIT.FINAL_CONV_KERNEL,
749 | stride=1,
750 | padding=1 if config.EXIT.FINAL_CONV_KERNEL == 3 else 0,
751 | bias=True,
752 | )
753 | ]
754 |
755 |
756 | elif layer_type == 'downup_pool_1x1_inter_triple':
757 | exit_layer = [
758 | nn.AvgPool2d(2, 2),
759 | exit_conv(
760 | in_channels=num_channels,
761 | out_channels=inter_channel,
762 | kernel_size=1,
763 | stride=1,
764 | padding=0,
765 | bias=True),
766 | BatchNorm2d(inter_channel, momentum=BN_MOMENTUM),
767 | nn.ReLU(inplace=True),
768 |
769 | nn.AvgPool2d(2, 2),
770 | exit_conv(
771 | in_channels=inter_channel,
772 | out_channels=inter_channel,
773 | kernel_size=1,
774 | stride=1,
775 | padding=0,
776 | bias=True
777 | ),
778 | BatchNorm2d(inter_channel, momentum=BN_MOMENTUM),
779 | nn.ReLU(inplace=True),
780 |
781 | nn.AvgPool2d(2, 2),
782 | exit_conv(
783 | in_channels=inter_channel,
784 | out_channels=inter_channel,
785 | kernel_size=1,
786 | stride=1,
787 | padding=0,
788 | bias=True
789 | ),
790 | BatchNorm2d(inter_channel, momentum=BN_MOMENTUM),
791 | nn.ReLU(inplace=True),
792 |
793 | nn.Upsample(scale_factor=2, mode='bilinear'),
794 | exit_conv(
795 | in_channels=inter_channel,
796 | out_channels=inter_channel,
797 | kernel_size=1,
798 | stride=1,
799 | padding=0,
800 | bias=True
801 | ),
802 | BatchNorm2d(inter_channel, momentum=BN_MOMENTUM),
803 | nn.ReLU(inplace=True),
804 |
805 | nn.Upsample(scale_factor=2, mode='bilinear'),
806 | exit_conv(
807 | in_channels=inter_channel,
808 | out_channels=inter_channel,
809 | kernel_size=1,
810 | stride=1,
811 | padding=0,
812 | bias=True
813 | ),
814 | BatchNorm2d(inter_channel, momentum=BN_MOMENTUM),
815 | nn.ReLU(inplace=True),
816 |
817 | nn.Upsample(scale_factor=2, mode='bilinear'),
818 | exit_conv(
819 | in_channels=inter_channel,
820 | out_channels=config.DATASET.NUM_CLASSES,
821 | kernel_size=config.EXIT.FINAL_CONV_KERNEL,
822 | stride=1,
823 | padding=1 if config.EXIT.FINAL_CONV_KERNEL == 3 else 0,
824 | bias=True,
825 | )
826 | ]
827 |
828 | exit_layer = nn.Sequential(*exit_layer)
829 |
830 | return exit_layer
831 |
832 | def get_exit_input(self, x, detach=True):
833 | interpolated_list = [x[0]]
834 | x0_h, x0_w = x[0].size(2), x[0].size(3)
835 |
836 | for i in range(1, len(x)):
837 |             interpolated_list.append(F.interpolate(x[i], size=(x0_h, x0_w), mode='bilinear'))
838 |
839 | ret = torch.cat(interpolated_list, 1)
840 |
841 | return ret.detach() if detach else ret
842 |
843 |
844 |
845 | def init_weights(self, pretrained='', load_stage=1):
846 | logger.info('=> init weights from normal distribution')
847 | for m in self.modules():
848 | if isinstance(m, nn.Conv2d):
849 | nn.init.normal_(m.weight, std=0.001)
850 | elif isinstance(m, nn.BatchNorm2d):
851 | nn.init.constant_(m.weight, 1)
852 | nn.init.constant_(m.bias, 0)
853 |
854 | if os.path.isfile(pretrained) and load_stage == 0:
855 | pretrained_dict = torch.load(pretrained)
856 | logger.info('=> loading pretrained model {}'.format(pretrained))
857 | model_dict = self.state_dict()
858 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()}
859 | elif os.path.isfile(pretrained) and load_stage == 1:
860 | pretrained_dict = torch.load(pretrained)
861 | logger.info('=> loading pretrained model {}'.format(pretrained))
862 | model_dict = self.state_dict()
863 | pretrained_dict = {k[len('model.'):]: v for k, v in pretrained_dict.items() if k[len('model.'):] in model_dict.keys()}
864 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if not k.startswith('exit')}
865 |
866 | elif os.path.isfile(pretrained) and load_stage == 2:
867 | pretrained_dict = torch.load(pretrained)
868 | logger.info('=> loading pretrained model {}'.format(pretrained))
869 | model_dict = self.state_dict()
870 | pretrained_dict = {k[len('model.'):]: v for k, v in pretrained_dict.items() if k[len('model.'):] in model_dict.keys()}
871 |
872 |         if os.path.isfile(pretrained):
873 |             logger.info('loading stage: {}, loading {} dict keys'.format(load_stage, len(pretrained_dict)))
874 |             model_dict.update(pretrained_dict)
875 |             self.load_state_dict(model_dict)
876 |
877 | class L2Norm(nn.Module):
878 | def __init__(self):
879 | super(L2Norm, self).__init__()
880 | def forward(self, x):
881 | return F.normalize(x, p=2, dim=0)
882 |
883 |
884 | class TemperatureScaling(nn.Module):
885 | def __init__(self, channel_wise, location_wise):
886 | super(TemperatureScaling, self).__init__()
887 |
888 | self.channel_wise = channel_wise
889 | self.location_wise = location_wise
890 | self.shift = 0.5413
891 |
892 | def forward(self, x):
893 | pass
894 |
895 | class TemperatureScalingFixed(TemperatureScaling):
896 | def __init__(self, channel_wise=False, location_wise=False, num_channels=0):
897 |
898 | super(TemperatureScalingFixed, self).__init__(channel_wise=channel_wise, location_wise=location_wise)
899 | self.num_channels = num_channels
900 |
901 | assert (not self.location_wise)
902 |
903 | if channel_wise:
904 | self.t_vector = nn.Parameter(torch.zeros(num_channels), requires_grad=True)
905 | else:
906 | self.t = nn.Parameter(torch.zeros(1), requires_grad=True)
907 |
908 | def forward(self, x):
909 |
910 | if self.channel_wise:
911 | positive_t_vector = F.softplus(self.t_vector + self.shift)
912 | out = x * positive_t_vector[None, :, None, None]
913 | else:
914 | positive_t = F.softplus(self.t + self.shift)
915 | out = x * positive_t
916 | return out
917 |
918 | class TemperatureScalingPredicted(TemperatureScaling):
919 | def __init__(self, channel_wise=False, location_wise=False, in_channels=0, layer_type='conv1'):
920 | super(TemperatureScalingPredicted, self).__init__(channel_wise=channel_wise, location_wise=location_wise)
921 | assert self.location_wise
922 |
923 | self.in_channels = in_channels
924 | self.layer_type = layer_type
925 |
926 | if self.layer_type == 'conv1':
927 | self.layer = used_conv(in_channels, 1, kernel_size=1, padding=0)
928 | elif self.layer_type == 'conv3':
929 | self.layer = used_conv(in_channels, 1, kernel_size=3, padding=1)
930 | elif self.layer_type == 'default_exit':
931 | self.layer = nn.Sequential(
932 | used_conv(
933 | in_channels=in_channels,
934 | out_channels=in_channels,
935 | kernel_size=1,
936 | stride=1,
937 | padding=0),
938 | BatchNorm2d(in_channels, momentum=BN_MOMENTUM),
939 | nn.ReLU(inplace=True),
940 | used_conv(
941 | in_channels=in_channels,
942 | out_channels=1,
943 | kernel_size=1,
944 | stride=1,
945 | padding=0)
946 | )
947 | else:
948 | raise NotImplementedError('TemperatureScalingPredicted layer type {} not implemented!'.format(self.layer_type))
949 |
950 | def forward(self, x):
951 | logits = x[0]
952 | features = x[1]
953 | self.t_map = self.layer(features) * 1.0
954 | self.positive_t_map = F.softplus(self.t_map + self.shift)
955 | return logits * self.positive_t_map
956 |
957 |
958 | def get_seg_model(cfg, **kwargs):
959 | model = HighResolutionNet(cfg, **kwargs)
960 | model.init_weights(cfg.MODEL.PRETRAINED, cfg.MODEL.LOAD_STAGE)
961 |
962 | return model
963 |
964 | if __name__ == '__main__':
965 | from config import config
966 | from config import update_config
967 | import argparse
968 | import torch.backends.cudnn as cudnn
969 |
970 | def parse_args():
971 | parser = argparse.ArgumentParser(description='Train segmentation network')
972 |
973 | parser.add_argument('--cfg',
974 | help='experiment configure file name',
975 | type=str, default='experiments/cityscapes/seg_hrnet_ee_0715_mask.yaml')
976 | parser.add_argument('opts',
977 | help="Modify config options using the command-line",
978 | default=None,
979 | nargs=argparse.REMAINDER)
980 | args = parser.parse_args()
981 | update_config(config, args)
982 | return args
983 |
984 | args = parse_args()
985 | cudnn.benchmark = config.CUDNN.BENCHMARK
986 | cudnn.deterministic = config.CUDNN.DETERMINISTIC
987 | cudnn.enabled = config.CUDNN.ENABLED
988 |
989 |     model = get_seg_model(config)
990 | model = nn.DataParallel(model, device_ids=[0]).cuda()
991 |
992 | for i in range(20):
993 | print(i)
994 | dump_input = torch.rand(
995 | (1, 3, config.TRAIN.IMAGE_SIZE[1]//4, config.TRAIN.IMAGE_SIZE[0]//4)
996 | )
997 | out = model(dump_input)
998 |
999 | def count_parameters(model):
1000 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
1001 | print(count_parameters(model))
1002 |
1003 |
--------------------------------------------------------------------------------
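`HighResolutionNet.forward` above returns the logits of all four exits as a list during training; for anytime inference it checks `hasattr(self, 'stop1')` through `hasattr(self, 'stop4')` to cut computation short, and, when `MASK.USE` is set, `set_part_masks` restricts the next stage's masked convolutions to pixels the current exit is not yet confident about. A hedged sketch of truncating at the second exit (config plumbing elided; `model` is assumed to come from `get_seg_model` and to be unwrapped, not inside `DataParallel`):

```
import torch

model.eval()
model.stop2 = True                    # forward() only tests hasattr, any value works
with torch.no_grad():
    out2 = model(torch.randn(1, 3, 512, 1024))   # logits from exit 2 only
del model.stop2                       # restore the full four-exit forward
```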
/lib/models/sync_bn/__init__.py:
--------------------------------------------------------------------------------
1 | from .inplace_abn import bn
--------------------------------------------------------------------------------
/lib/models/sync_bn/inplace_abn/__init__.py:
--------------------------------------------------------------------------------
1 | from .bn import ABN, InPlaceABN, InPlaceABNSync
2 | from .functions import ACT_RELU, ACT_LEAKY_RELU, ACT_ELU, ACT_NONE
3 |
--------------------------------------------------------------------------------
/lib/models/sync_bn/inplace_abn/bn.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as functional
5 |
6 | try:
7 | from queue import Queue
8 | except ImportError:
9 | from Queue import Queue
10 |
11 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
12 | sys.path.append(BASE_DIR)
13 | sys.path.append(os.path.join(BASE_DIR, '../src'))
14 | from functions import *
15 |
16 |
17 | class ABN(nn.Module):
18 | """Activated Batch Normalization
19 |
20 | This gathers a `BatchNorm2d` and an activation function in a single module
21 | """
22 |
23 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", slope=0.01):
24 | """Creates an Activated Batch Normalization module
25 |
26 | Parameters
27 | ----------
28 | num_features : int
29 | Number of feature channels in the input and output.
30 | eps : float
31 | Small constant to prevent numerical issues.
32 | momentum : float
33 |             Momentum factor applied to compute running statistics.
34 | affine : bool
35 | If `True` apply learned scale and shift transformation after normalization.
36 | activation : str
37 | Name of the activation functions, one of: `leaky_relu`, `elu` or `none`.
38 | slope : float
39 | Negative slope for the `leaky_relu` activation.
40 | """
41 | super(ABN, self).__init__()
42 | self.num_features = num_features
43 | self.affine = affine
44 | self.eps = eps
45 | self.momentum = momentum
46 | self.activation = activation
47 | self.slope = slope
48 | if self.affine:
49 | self.weight = nn.Parameter(torch.ones(num_features))
50 | self.bias = nn.Parameter(torch.zeros(num_features))
51 | else:
52 | self.register_parameter('weight', None)
53 | self.register_parameter('bias', None)
54 | self.register_buffer('running_mean', torch.zeros(num_features))
55 | self.register_buffer('running_var', torch.ones(num_features))
56 | self.reset_parameters()
57 |
58 | def reset_parameters(self):
59 | nn.init.constant_(self.running_mean, 0)
60 | nn.init.constant_(self.running_var, 1)
61 | if self.affine:
62 | nn.init.constant_(self.weight, 1)
63 | nn.init.constant_(self.bias, 0)
64 |
65 | def forward(self, x):
66 | x = functional.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias,
67 | self.training, self.momentum, self.eps)
68 |
69 | if self.activation == ACT_RELU:
70 | return functional.relu(x, inplace=True)
71 | elif self.activation == ACT_LEAKY_RELU:
72 | return functional.leaky_relu(x, negative_slope=self.slope, inplace=True)
73 | elif self.activation == ACT_ELU:
74 | return functional.elu(x, inplace=True)
75 | else:
76 | return x
77 |
78 | def __repr__(self):
79 | rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \
80 | ' affine={affine}, activation={activation}'
81 | if self.activation == "leaky_relu":
82 | rep += ', slope={slope})'
83 | else:
84 | rep += ')'
85 | return rep.format(name=self.__class__.__name__, **self.__dict__)
86 |
87 |
88 | class InPlaceABN(ABN):
89 | """InPlace Activated Batch Normalization"""
90 |
91 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", slope=0.01):
92 | """Creates an InPlace Activated Batch Normalization module
93 |
94 | Parameters
95 | ----------
96 | num_features : int
97 | Number of feature channels in the input and output.
98 | eps : float
99 | Small constant to prevent numerical issues.
100 | momentum : float
101 |             Momentum factor applied to compute running statistics.
102 | affine : bool
103 | If `True` apply learned scale and shift transformation after normalization.
104 | activation : str
105 | Name of the activation functions, one of: `leaky_relu`, `elu` or `none`.
106 | slope : float
107 | Negative slope for the `leaky_relu` activation.
108 | """
109 | super(InPlaceABN, self).__init__(num_features, eps, momentum, affine, activation, slope)
110 |
111 | def forward(self, x):
112 | return inplace_abn(x, self.weight, self.bias, self.running_mean, self.running_var,
113 | self.training, self.momentum, self.eps, self.activation, self.slope)
114 |
115 |
116 | class InPlaceABNSync(ABN):
117 | """InPlace Activated Batch Normalization with cross-GPU synchronization
118 |
119 | This assumes that it will be replicated across GPUs using the same mechanism as in `nn.DataParallel`.
120 | """
121 |
122 | def __init__(self, num_features, devices=None, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu",
123 | slope=0.01):
124 | """Creates a synchronized, InPlace Activated Batch Normalization module
125 |
126 | Parameters
127 | ----------
128 | num_features : int
129 | Number of feature channels in the input and output.
130 | devices : list of int or None
131 | IDs of the GPUs that will run the replicas of this module.
132 | eps : float
133 | Small constant to prevent numerical issues.
134 | momentum : float
135 |             Momentum factor applied to compute running statistics.
136 | affine : bool
137 | If `True` apply learned scale and shift transformation after normalization.
138 | activation : str
139 | Name of the activation functions, one of: `leaky_relu`, `elu` or `none`.
140 | slope : float
141 | Negative slope for the `leaky_relu` activation.
142 | """
143 | super(InPlaceABNSync, self).__init__(num_features, eps, momentum, affine, activation, slope)
144 | self.devices = devices if devices else list(range(torch.cuda.device_count()))
145 |
146 | # Initialize queues
147 | self.worker_ids = self.devices[1:]
148 | self.master_queue = Queue(len(self.worker_ids))
149 | self.worker_queues = [Queue(1) for _ in self.worker_ids]
150 |
151 | def forward(self, x):
152 | if x.get_device() == self.devices[0]:
153 | # Master mode
154 | extra = {
155 | "is_master": True,
156 | "master_queue": self.master_queue,
157 | "worker_queues": self.worker_queues,
158 | "worker_ids": self.worker_ids
159 | }
160 | else:
161 | # Worker mode
162 | extra = {
163 | "is_master": False,
164 | "master_queue": self.master_queue,
165 | "worker_queue": self.worker_queues[self.worker_ids.index(x.get_device())]
166 | }
167 |
168 | return inplace_abn_sync(x, self.weight, self.bias, self.running_mean, self.running_var,
169 | extra, self.training, self.momentum, self.eps, self.activation, self.slope)
170 |
171 | def __repr__(self):
172 | rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \
173 | ' affine={affine}, devices={devices}, activation={activation}'
174 | if self.activation == "leaky_relu":
175 | rep += ', slope={slope})'
176 | else:
177 | rep += ')'
178 | return rep.format(name=self.__class__.__name__, **self.__dict__)
179 |
--------------------------------------------------------------------------------
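`ABN` fuses normalization and activation into one module, so a `BatchNorm2d` + `ReLU`/`LeakyReLU` pair can be swapped for a single layer; the `InPlaceABN(Sync)` subclasses additionally overwrite their input buffer to cut activation memory. A minimal sketch of the plain `ABN` variant (the import path follows this repo's `lib` layout and assumes `lib` is on `sys.path`; note that importing the package also JIT-compiles the extension via `functions.py`):

```
import torch
import torch.nn as nn
from models.sync_bn.inplace_abn import ABN   # exported by the package __init__ above

block = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
    ABN(64, activation='leaky_relu', slope=0.01),  # BN and leaky ReLU in one module
)
y = block(torch.randn(2, 3, 32, 32))              # -> (2, 64, 32, 32)
```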
/lib/models/sync_bn/inplace_abn/functions.py:
--------------------------------------------------------------------------------
1 | from os import path
2 |
3 | import torch.autograd as autograd
4 | import torch.cuda.comm as comm
5 | from torch.autograd.function import once_differentiable
6 | from torch.utils.cpp_extension import load
7 |
8 | _src_path = path.join(path.dirname(path.abspath(__file__)), "src")
9 | _backend = load(name="inplace_abn",
10 | extra_cflags=["-O3"],
11 | sources=[path.join(_src_path, f) for f in [
12 | "inplace_abn.cpp",
13 | "inplace_abn_cpu.cpp",
14 | "inplace_abn_cuda.cu"
15 | ]],
16 | extra_cuda_cflags=["--expt-extended-lambda"])
17 |
18 | # Activation names
19 | ACT_RELU = "relu"
20 | ACT_LEAKY_RELU = "leaky_relu"
21 | ACT_ELU = "elu"
22 | ACT_NONE = "none"
23 |
24 |
25 | def _check(fn, *args, **kwargs):
26 | success = fn(*args, **kwargs)
27 | if not success:
28 | raise RuntimeError("CUDA Error encountered in {}".format(fn))
29 |
30 |
31 | def _broadcast_shape(x):
32 | out_size = []
33 | for i, s in enumerate(x.size()):
34 | if i != 1:
35 | out_size.append(1)
36 | else:
37 | out_size.append(s)
38 | return out_size
39 |
40 |
41 | def _reduce(x):
42 | if len(x.size()) == 2:
43 | return x.sum(dim=0)
44 | else:
45 | n, c = x.size()[0:2]
46 | return x.contiguous().view((n, c, -1)).sum(2).sum(0)
47 |
48 |
49 | def _count_samples(x):
50 | count = 1
51 | for i, s in enumerate(x.size()):
52 | if i != 1:
53 | count *= s
54 | return count
55 |
56 |
57 | def _act_forward(ctx, x):
58 | if ctx.activation == ACT_LEAKY_RELU:
59 | _backend.leaky_relu_forward(x, ctx.slope)
60 | elif ctx.activation == ACT_ELU:
61 | _backend.elu_forward(x)
62 | elif ctx.activation == ACT_NONE:
63 | pass
64 |
65 |
66 | def _act_backward(ctx, x, dx):
67 | if ctx.activation == ACT_LEAKY_RELU:
68 | _backend.leaky_relu_backward(x, dx, ctx.slope)
69 | elif ctx.activation == ACT_ELU:
70 | _backend.elu_backward(x, dx)
71 | elif ctx.activation == ACT_NONE:
72 | pass
73 |
74 |
75 | class InPlaceABN(autograd.Function):
76 | @staticmethod
77 | def forward(ctx, x, weight, bias, running_mean, running_var,
78 | training=True, momentum=0.1, eps=1e-05, activation=ACT_LEAKY_RELU, slope=0.01):
79 | ctx.training = training
80 | ctx.momentum = momentum
81 | ctx.eps = eps
82 | ctx.activation = activation
83 | ctx.slope = slope
84 | ctx.affine = weight is not None and bias is not None
85 |
86 | count = _count_samples(x)
87 | x = x.contiguous()
88 | weight = weight.contiguous() if ctx.affine else x.new_empty(0)
89 | bias = bias.contiguous() if ctx.affine else x.new_empty(0)
90 |
91 | if ctx.training:
92 | mean, var = _backend.mean_var(x)
93 |
94 | running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean)
95 | running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var * count / (count - 1))
96 |
97 | ctx.mark_dirty(x, running_mean, running_var)
98 | else:
99 | mean, var = running_mean.contiguous(), running_var.contiguous()
100 | ctx.mark_dirty(x)
101 |
102 | _backend.forward(x, mean, var, weight, bias, ctx.affine, ctx.eps)
103 | _act_forward(ctx, x)
104 |
105 | ctx.var = var
106 | ctx.save_for_backward(x, var, weight, bias)
107 | return x
108 |
109 | @staticmethod
110 | @once_differentiable
111 | def backward(ctx, dz):
112 | z, var, weight, bias = ctx.saved_tensors
113 | dz = dz.contiguous()
114 |
115 | _act_backward(ctx, z, dz)
116 |
117 | if ctx.training:
118 | edz, eydz = _backend.edz_eydz(z, dz, weight, bias, ctx.affine, ctx.eps)
119 | else:
120 | edz = dz.new_zeros(dz.size(1))
121 | eydz = dz.new_zeros(dz.size(1))
122 |
123 | dx, dweight, dbias = _backend.backward(z, dz, var, weight, bias, edz, eydz, ctx.affine, ctx.eps)
124 | dweight = dweight if ctx.affine else None
125 | dbias = dbias if ctx.affine else None
126 |
127 | return dx, dweight, dbias, None, None, None, None, None, None, None
128 |
129 |
130 | class InPlaceABNSync(autograd.Function):
131 | @classmethod
132 | def forward(cls, ctx, x, weight, bias, running_mean, running_var,
133 | extra, training=True, momentum=0.1, eps=1e-05, activation=ACT_LEAKY_RELU, slope=0.01):
134 | cls._parse_extra(ctx, extra)
135 | ctx.training = training
136 | ctx.momentum = momentum
137 | ctx.eps = eps
138 | ctx.activation = activation
139 | ctx.slope = slope
140 | ctx.affine = weight is not None and bias is not None
141 |
142 | count = _count_samples(x) * (ctx.master_queue.maxsize + 1)
143 | x = x.contiguous()
144 | weight = weight.contiguous() if ctx.affine else x.new_empty(0)
145 | bias = bias.contiguous() if ctx.affine else x.new_empty(0)
146 |
147 | if ctx.training:
148 | mean, var = _backend.mean_var(x)
149 |
150 | if ctx.is_master:
151 | means, vars = [mean.unsqueeze(0)], [var.unsqueeze(0)]
152 | for _ in range(ctx.master_queue.maxsize):
153 | mean_w, var_w = ctx.master_queue.get()
154 | ctx.master_queue.task_done()
155 | means.append(mean_w.unsqueeze(0))
156 | vars.append(var_w.unsqueeze(0))
157 |
158 | means = comm.gather(means)
159 | vars = comm.gather(vars)
160 |
161 | mean = means.mean(0)
162 | var = (vars + (mean - means) ** 2).mean(0)
163 |
164 | tensors = comm.broadcast_coalesced((mean, var), [mean.get_device()] + ctx.worker_ids)
165 | for ts, queue in zip(tensors[1:], ctx.worker_queues):
166 | queue.put(ts)
167 | else:
168 | ctx.master_queue.put((mean, var))
169 | mean, var = ctx.worker_queue.get()
170 | ctx.worker_queue.task_done()
171 |
172 | running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean)
173 | running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var * count / (count - 1))
174 |
175 | ctx.mark_dirty(x, running_mean, running_var)
176 | else:
177 | mean, var = running_mean.contiguous(), running_var.contiguous()
178 | ctx.mark_dirty(x)
179 |
180 | _backend.forward(x, mean, var, weight, bias, ctx.affine, ctx.eps)
181 | _act_forward(ctx, x)
182 |
183 | ctx.var = var
184 | ctx.save_for_backward(x, var, weight, bias)
185 | return x
186 |
187 | @staticmethod
188 | @once_differentiable
189 | def backward(ctx, dz):
190 | z, var, weight, bias = ctx.saved_tensors
191 | dz = dz.contiguous()
192 |
193 | _act_backward(ctx, z, dz)
194 |
195 | if ctx.training:
196 | edz, eydz = _backend.edz_eydz(z, dz, weight, bias, ctx.affine, ctx.eps)
197 |
198 | if ctx.is_master:
199 | edzs, eydzs = [edz], [eydz]
200 | for _ in range(len(ctx.worker_queues)):
201 | edz_w, eydz_w = ctx.master_queue.get()
202 | ctx.master_queue.task_done()
203 | edzs.append(edz_w)
204 | eydzs.append(eydz_w)
205 |
206 | edz = comm.reduce_add(edzs) / (ctx.master_queue.maxsize + 1)
207 | eydz = comm.reduce_add(eydzs) / (ctx.master_queue.maxsize + 1)
208 |
209 | tensors = comm.broadcast_coalesced((edz, eydz), [edz.get_device()] + ctx.worker_ids)
210 | for ts, queue in zip(tensors[1:], ctx.worker_queues):
211 | queue.put(ts)
212 | else:
213 | ctx.master_queue.put((edz, eydz))
214 | edz, eydz = ctx.worker_queue.get()
215 | ctx.worker_queue.task_done()
216 | else:
217 | edz = dz.new_zeros(dz.size(1))
218 | eydz = dz.new_zeros(dz.size(1))
219 |
220 | dx, dweight, dbias = _backend.backward(z, dz, var, weight, bias, edz, eydz, ctx.affine, ctx.eps)
221 | dweight = dweight if ctx.affine else None
222 | dbias = dbias if ctx.affine else None
223 |
224 | return dx, dweight, dbias, None, None, None, None, None, None, None, None
225 |
226 | @staticmethod
227 | def _parse_extra(ctx, extra):
228 | ctx.is_master = extra["is_master"]
229 | if ctx.is_master:
230 | ctx.master_queue = extra["master_queue"]
231 | ctx.worker_queues = extra["worker_queues"]
232 | ctx.worker_ids = extra["worker_ids"]
233 | else:
234 | ctx.master_queue = extra["master_queue"]
235 | ctx.worker_queue = extra["worker_queue"]
236 |
237 |
238 | inplace_abn = InPlaceABN.apply
239 | inplace_abn_sync = InPlaceABNSync.apply
240 |
241 | __all__ = ["inplace_abn", "inplace_abn_sync", "ACT_RELU", "ACT_LEAKY_RELU", "ACT_ELU", "ACT_NONE"]
242 |
--------------------------------------------------------------------------------
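The queue handshake in `InPlaceABNSync.forward` exists to compute global batch statistics across replicas: each GPU contributes a local per-channel mean/variance, and the master merges them with the law of total variance, `var = mean_i(var_i + (mu_i - mu)^2)`, which is exact here because every replica sees the same number of samples. A small CPU-only sketch of that merge:

```
# Sketch of the statistic reduction done in InPlaceABNSync.forward (training path).
import torch

chunks = [torch.randn(8, 16, 4, 4) for _ in range(3)]    # one tensor per "replica"

def channel_stats(x):
    flat = x.transpose(0, 1).reshape(x.size(1), -1)      # (C, N*H*W)
    return flat.mean(1), flat.var(1, unbiased=False)

stats = [channel_stats(c) for c in chunks]
means = torch.stack([m for m, _ in stats])               # (replicas, C)
vars_ = torch.stack([v for _, v in stats])

mean = means.mean(0)                                     # global per-channel mean
var = (vars_ + (mean - means) ** 2).mean(0)              # law of total variance

gm, gv = channel_stats(torch.cat(chunks, 0))             # ground truth on the full batch
assert torch.allclose(mean, gm, atol=1e-4) and torch.allclose(var, gv, atol=1e-4)
```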
/lib/models/sync_bn/inplace_abn/src/common.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <ATen/ATen.h>
4 |
5 | const int WARP_SIZE = 32;
6 | const int MAX_BLOCK_SIZE = 512;
7 |
8 | template<typename T>
9 | struct Pair {
10 | T v1, v2;
11 | __device__ Pair() {}
12 | __device__ Pair(T _v1, T _v2) : v1(_v1), v2(_v2) {}
13 | __device__ Pair(T v) : v1(v), v2(v) {}
14 | __device__ Pair(int v) : v1(v), v2(v) {}
15 | __device__ Pair &operator+=(const Pair &a) {
16 | v1 += a.v1;
17 | v2 += a.v2;
18 | return *this;
19 | }
20 | };
21 |
22 | template<typename T>
23 | __device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask, int width = warpSize,
24 | unsigned int mask = 0xffffffff) {
25 | #if CUDART_VERSION >= 9000
26 | return __shfl_xor_sync(mask, value, laneMask, width);
27 | #else
28 | return __shfl_xor(value, laneMask, width);
29 | #endif
30 | }
31 |
32 | __device__ __forceinline__ int getMSB(int val) { return 31 - __clz(val); }
33 |
34 | static int getNumThreads(int nElem) {
35 | int threadSizes[5] = {32, 64, 128, 256, MAX_BLOCK_SIZE};
36 | for (int i = 0; i != 5; ++i) {
37 | if (nElem <= threadSizes[i]) {
38 | return threadSizes[i];
39 | }
40 | }
41 | return MAX_BLOCK_SIZE;
42 | }
43 |
44 | template<typename T>
45 | static __device__ __forceinline__ T warpSum(T val) {
46 | #if __CUDA_ARCH__ >= 300
47 | for (int i = 0; i < getMSB(WARP_SIZE); ++i) {
48 | val += WARP_SHFL_XOR(val, 1 << i, WARP_SIZE);
49 | }
50 | #else
51 | __shared__ T values[MAX_BLOCK_SIZE];
52 | values[threadIdx.x] = val;
53 | __threadfence_block();
54 | const int base = (threadIdx.x / WARP_SIZE) * WARP_SIZE;
55 | for (int i = 1; i < WARP_SIZE; i++) {
56 | val += values[base + ((i + threadIdx.x) % WARP_SIZE)];
57 | }
58 | #endif
59 | return val;
60 | }
61 |
62 | template<typename T>
63 | static __device__ __forceinline__ Pair<T> warpSum(Pair<T> value) {
64 | value.v1 = warpSum(value.v1);
65 | value.v2 = warpSum(value.v2);
66 | return value;
67 | }
68 |
69 | template<typename T, typename Op>
70 | __device__ T reduce(Op op, int plane, int N, int C, int S) {
71 | T sum = (T)0;
72 | for (int batch = 0; batch < N; ++batch) {
73 | for (int x = threadIdx.x; x < S; x += blockDim.x) {
74 | sum += op(batch, plane, x);
75 | }
76 | }
77 |
78 | sum = warpSum(sum);
79 |
80 | __shared__ T shared[32];
81 | __syncthreads();
82 | if (threadIdx.x % WARP_SIZE == 0) {
83 | shared[threadIdx.x / WARP_SIZE] = sum;
84 | }
85 | if (threadIdx.x >= blockDim.x / WARP_SIZE && threadIdx.x < WARP_SIZE) {
86 | shared[threadIdx.x] = (T)0;
87 | }
88 | __syncthreads();
89 | if (threadIdx.x / WARP_SIZE == 0) {
90 | sum = warpSum(shared[threadIdx.x]);
91 | if (threadIdx.x == 0) {
92 | shared[0] = sum;
93 | }
94 | }
95 | __syncthreads();
96 |
97 | return shared[0];
98 | }
--------------------------------------------------------------------------------
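`warpSum` above is the standard butterfly reduction: on each of the `log2(32)` rounds, every lane exchanges with the lane whose index differs in one bit (`__shfl_xor_sync` with strides 1, 2, 4, 8, 16), so after five rounds all 32 lanes hold the full sum. A pure-Python model of the same exchange pattern:

```
# Lockstep model of the XOR-shuffle butterfly reduction in warpSum (common.h).
WARP_SIZE = 32
vals = list(range(WARP_SIZE))                  # one value per lane

step = 1
while step < WARP_SIZE:
    # every lane i reads lane i ^ step and adds it, all lanes at once
    vals = [vals[i] + vals[i ^ step] for i in range(WARP_SIZE)]
    step <<= 1

assert all(v == sum(range(WARP_SIZE)) for v in vals)   # every lane has the total
```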
/lib/models/sync_bn/inplace_abn/src/inplace_abn.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 | #include <vector>
4 |
5 | #include "inplace_abn.h"
6 |
7 | std::vector<at::Tensor> mean_var(at::Tensor x) {
8 | if (x.is_cuda()) {
9 | return mean_var_cuda(x);
10 | } else {
11 | return mean_var_cpu(x);
12 | }
13 | }
14 |
15 | at::Tensor forward(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
16 | bool affine, float eps) {
17 | if (x.is_cuda()) {
18 | return forward_cuda(x, mean, var, weight, bias, affine, eps);
19 | } else {
20 | return forward_cpu(x, mean, var, weight, bias, affine, eps);
21 | }
22 | }
23 |
24 | std::vector<at::Tensor> edz_eydz(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
25 | bool affine, float eps) {
26 | if (z.is_cuda()) {
27 | return edz_eydz_cuda(z, dz, weight, bias, affine, eps);
28 | } else {
29 | return edz_eydz_cpu(z, dz, weight, bias, affine, eps);
30 | }
31 | }
32 |
33 | std::vector<at::Tensor> backward(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
34 | at::Tensor edz, at::Tensor eydz, bool affine, float eps) {
35 | if (z.is_cuda()) {
36 | return backward_cuda(z, dz, var, weight, bias, edz, eydz, affine, eps);
37 | } else {
38 | return backward_cpu(z, dz, var, weight, bias, edz, eydz, affine, eps);
39 | }
40 | }
41 |
42 | void leaky_relu_forward(at::Tensor z, float slope) {
43 | at::leaky_relu_(z, slope);
44 | }
45 |
46 | void leaky_relu_backward(at::Tensor z, at::Tensor dz, float slope) {
47 | if (z.is_cuda()) {
48 | return leaky_relu_backward_cuda(z, dz, slope);
49 | } else {
50 | return leaky_relu_backward_cpu(z, dz, slope);
51 | }
52 | }
53 |
54 | void elu_forward(at::Tensor z) {
55 | at::elu_(z);
56 | }
57 |
58 | void elu_backward(at::Tensor z, at::Tensor dz) {
59 | if (z.is_cuda()) {
60 | return elu_backward_cuda(z, dz);
61 | } else {
62 | return elu_backward_cpu(z, dz);
63 | }
64 | }
65 |
66 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
67 | m.def("mean_var", &mean_var, "Mean and variance computation");
68 | m.def("forward", &forward, "In-place forward computation");
69 | m.def("edz_eydz", &edz_eydz, "First part of backward computation");
70 | m.def("backward", &backward, "Second part of backward computation");
71 | m.def("leaky_relu_forward", &leaky_relu_forward, "Leaky relu forward computation");
72 | m.def("leaky_relu_backward", &leaky_relu_backward, "Leaky relu backward computation and inversion");
73 | m.def("elu_forward", &elu_forward, "Elu forward computation");
74 | m.def("elu_backward", &elu_backward, "Elu backward computation and inversion");
75 | }
--------------------------------------------------------------------------------
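Each binding above dispatches on `x.is_cuda()` to the CPU or CUDA implementation declared in `inplace_abn.h`; from Python they are reached through the `_backend` module that `functions.py` JIT-compiles at import time. A hedged sketch of exercising the extension directly (it requires a working C++/CUDA toolchain to build, so treat it as illustrative):

```
import torch
from models.sync_bn.inplace_abn.functions import _backend   # triggers the JIT build

x = torch.randn(4, 8, 16, 16)
mean, var = _backend.mean_var(x)                 # per-channel biased statistics
flat = x.transpose(0, 1).reshape(8, -1)
assert torch.allclose(mean, flat.mean(1), atol=1e-5)
assert torch.allclose(var, flat.var(1, unbiased=False), atol=1e-5)
```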
/lib/models/sync_bn/inplace_abn/src/inplace_abn.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <ATen/ATen.h>
4 |
5 | #include <vector>
6 |
7 | std::vector<at::Tensor> mean_var_cpu(at::Tensor x);
8 | std::vector<at::Tensor> mean_var_cuda(at::Tensor x);
9 |
10 | at::Tensor forward_cpu(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
11 | bool affine, float eps);
12 | at::Tensor forward_cuda(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
13 | bool affine, float eps);
14 |
15 | std::vector edz_eydz_cpu(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
16 | bool affine, float eps);
17 | std::vector edz_eydz_cuda(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
18 | bool affine, float eps);
19 |
20 | std::vector backward_cpu(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
21 | at::Tensor edz, at::Tensor eydz, bool affine, float eps);
22 | std::vector<at::Tensor> backward_cuda(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
23 | at::Tensor edz, at::Tensor eydz, bool affine, float eps);
24 |
25 | void leaky_relu_backward_cpu(at::Tensor z, at::Tensor dz, float slope);
26 | void leaky_relu_backward_cuda(at::Tensor z, at::Tensor dz, float slope);
27 |
28 | void elu_backward_cpu(at::Tensor z, at::Tensor dz);
29 | void elu_backward_cuda(at::Tensor z, at::Tensor dz);
--------------------------------------------------------------------------------
/lib/models/sync_bn/inplace_abn/src/inplace_abn_cpu.cpp:
--------------------------------------------------------------------------------
1 | #include <ATen/ATen.h>
2 |
3 | #include <vector>
4 |
5 | #include "inplace_abn.h"
6 |
7 | at::Tensor reduce_sum(at::Tensor x) {
8 | if (x.ndimension() == 2) {
9 | return x.sum(0);
10 | } else {
11 | auto x_view = x.view({x.size(0), x.size(1), -1});
12 | return x_view.sum(-1).sum(0);
13 | }
14 | }
15 |
16 | at::Tensor broadcast_to(at::Tensor v, at::Tensor x) {
17 | if (x.ndimension() == 2) {
18 | return v;
19 | } else {
20 |     std::vector<int64_t> broadcast_size = {1, -1};
21 | for (int64_t i = 2; i < x.ndimension(); ++i)
22 | broadcast_size.push_back(1);
23 |
24 | return v.view(broadcast_size);
25 | }
26 | }
27 |
28 | int64_t count(at::Tensor x) {
29 | int64_t count = x.size(0);
30 | for (int64_t i = 2; i < x.ndimension(); ++i)
31 | count *= x.size(i);
32 |
33 | return count;
34 | }
35 |
36 | at::Tensor invert_affine(at::Tensor z, at::Tensor weight, at::Tensor bias, bool affine, float eps) {
37 | if (affine) {
38 | return (z - broadcast_to(bias, z)) / broadcast_to(at::abs(weight) + eps, z);
39 | } else {
40 | return z;
41 | }
42 | }
43 |
44 | std::vector<at::Tensor> mean_var_cpu(at::Tensor x) {
45 | auto num = count(x);
46 | auto mean = reduce_sum(x) / num;
47 | auto diff = x - broadcast_to(mean, x);
48 | auto var = reduce_sum(diff.pow(2)) / num;
49 |
50 | return {mean, var};
51 | }
52 |
53 | at::Tensor forward_cpu(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
54 | bool affine, float eps) {
55 | auto gamma = affine ? at::abs(weight) + eps : at::ones_like(var);
56 | auto mul = at::rsqrt(var + eps) * gamma;
57 |
58 | x.sub_(broadcast_to(mean, x));
59 | x.mul_(broadcast_to(mul, x));
60 | if (affine) x.add_(broadcast_to(bias, x));
61 |
62 | return x;
63 | }
64 |
65 | std::vector<at::Tensor> edz_eydz_cpu(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
66 | bool affine, float eps) {
67 | auto edz = reduce_sum(dz);
68 | auto y = invert_affine(z, weight, bias, affine, eps);
69 | auto eydz = reduce_sum(y * dz);
70 |
71 | return {edz, eydz};
72 | }
73 |
74 | std::vector<at::Tensor> backward_cpu(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
75 | at::Tensor edz, at::Tensor eydz, bool affine, float eps) {
76 | auto y = invert_affine(z, weight, bias, affine, eps);
77 | auto mul = affine ? at::rsqrt(var + eps) * (at::abs(weight) + eps) : at::rsqrt(var + eps);
78 |
79 | auto num = count(z);
80 | auto dx = (dz - broadcast_to(edz / num, dz) - y * broadcast_to(eydz / num, dz)) * broadcast_to(mul, dz);
81 |
82 | auto dweight = at::empty(z.type(), {0});
83 | auto dbias = at::empty(z.type(), {0});
84 | if (affine) {
85 | dweight = eydz * at::sign(weight);
86 | dbias = edz;
87 | }
88 |
89 | return {dx, dweight, dbias};
90 | }
91 |
92 | void leaky_relu_backward_cpu(at::Tensor z, at::Tensor dz, float slope) {
93 | AT_DISPATCH_FLOATING_TYPES(z.type(), "leaky_relu_backward_cpu", ([&] {
94 | int64_t count = z.numel();
95 |     auto *_z = z.data<scalar_t>();
96 |     auto *_dz = dz.data<scalar_t>();
97 |
98 | for (int64_t i = 0; i < count; ++i) {
99 | if (_z[i] < 0) {
100 | _z[i] *= 1 / slope;
101 | _dz[i] *= slope;
102 | }
103 | }
104 | }));
105 | }
106 |
107 | void elu_backward_cpu(at::Tensor z, at::Tensor dz) {
108 | AT_DISPATCH_FLOATING_TYPES(z.type(), "elu_backward_cpu", ([&] {
109 | int64_t count = z.numel();
110 |     auto *_z = z.data<scalar_t>();
111 |     auto *_dz = dz.data<scalar_t>();
112 |
113 | for (int64_t i = 0; i < count; ++i) {
114 | if (_z[i] < 0) {
115 |         _dz[i] *= (_z[i] + 1.f);  // dELU/dx = y + 1, using y = elu(x) before inversion
116 |         _z[i] = log1p(_z[i]);     // then invert: x = log1p(y), matching the CUDA path
117 | }
118 | }
119 | }));
120 | }
--------------------------------------------------------------------------------
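
The CPU path makes the in-place trick explicit: `forward_cpu` overwrites `x` with `z = (x - mean) * rsqrt(var + eps) * (|weight| + eps) + bias`, and `invert_affine` recovers the normalized activation from `z` alone, so the backward pass never needs the original input. A quick sanity sketch of that round trip in plain PyTorch:

```python
# Sanity sketch of the algebra in mean_var_cpu / forward_cpu / invert_affine.
import torch

eps = 1e-5
x = torch.randn(4, 3, 8, 8)
w, b = torch.randn(3), torch.randn(3)

def bc(v):  # broadcast a per-channel vector over NCHW, like broadcast_to()
    return v.view(1, -1, 1, 1)

mean = x.mean(dim=(0, 2, 3))
var = ((x - bc(mean)) ** 2).mean(dim=(0, 2, 3))   # biased, as in mean_var_cpu

y = (x - bc(mean)) * torch.rsqrt(bc(var) + eps)   # normalized activation
z = y * bc(w.abs() + eps) + bc(b)                 # what forward_cpu leaves in x
y_rec = (z - bc(b)) / bc(w.abs() + eps)           # invert_affine
assert torch.allclose(y, y_rec, atol=1e-5)
```
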
/lib/models/sync_bn/inplace_abn/src/inplace_abn_cuda.cu:
--------------------------------------------------------------------------------
1 | #include <ATen/ATen.h>
2 |
3 | #include <thrust/device_ptr.h>
4 | #include <thrust/transform.h>
5 |
6 | #include <vector>
7 |
8 | #include "common.h"
9 | #include "inplace_abn.h"
10 |
11 | // Checks
12 | #ifndef AT_CHECK
13 | #define AT_CHECK AT_ASSERT
14 | #endif
15 | #define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
16 | #define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous")
17 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
18 |
19 | // Utilities
20 | void get_dims(at::Tensor x, int64_t& num, int64_t& chn, int64_t& sp) {
21 | num = x.size(0);
22 | chn = x.size(1);
23 | sp = 1;
24 | for (int64_t i = 2; i < x.ndimension(); ++i)
25 | sp *= x.size(i);
26 | }
27 |
28 | // Operations for reduce
29 | template<typename T>
30 | struct SumOp {
31 | __device__ SumOp(const T *t, int c, int s)
32 | : tensor(t), chn(c), sp(s) {}
33 | __device__ __forceinline__ T operator()(int batch, int plane, int n) {
34 | return tensor[(batch * chn + plane) * sp + n];
35 | }
36 | const T *tensor;
37 | const int chn;
38 | const int sp;
39 | };
40 |
41 | template<typename T>
42 | struct VarOp {
43 | __device__ VarOp(T m, const T *t, int c, int s)
44 | : mean(m), tensor(t), chn(c), sp(s) {}
45 | __device__ __forceinline__ T operator()(int batch, int plane, int n) {
46 | T val = tensor[(batch * chn + plane) * sp + n];
47 | return (val - mean) * (val - mean);
48 | }
49 | const T mean;
50 | const T *tensor;
51 | const int chn;
52 | const int sp;
53 | };
54 |
55 | template<typename T>
56 | struct GradOp {
57 | __device__ GradOp(T _weight, T _bias, const T *_z, const T *_dz, int c, int s)
58 | : weight(_weight), bias(_bias), z(_z), dz(_dz), chn(c), sp(s) {}
59 |   __device__ __forceinline__ Pair<T> operator()(int batch, int plane, int n) {
60 | T _y = (z[(batch * chn + plane) * sp + n] - bias) / weight;
61 | T _dz = dz[(batch * chn + plane) * sp + n];
62 |     return Pair<T>(_dz, _y * _dz);
63 | }
64 | const T weight;
65 | const T bias;
66 | const T *z;
67 | const T *dz;
68 | const int chn;
69 | const int sp;
70 | };
71 |
72 | /***********
73 | * mean_var
74 | ***********/
75 |
76 | template<typename T>
77 | __global__ void mean_var_kernel(const T *x, T *mean, T *var, int num, int chn, int sp) {
78 | int plane = blockIdx.x;
79 | T norm = T(1) / T(num * sp);
80 |
81 |   T _mean = reduce<T, SumOp<T>>(SumOp<T>(x, chn, sp), plane, num, chn, sp) * norm;
82 |   __syncthreads();
83 |   T _var = reduce<T, VarOp<T>>(VarOp<T>(_mean, x, chn, sp), plane, num, chn, sp) * norm;
84 |
85 | if (threadIdx.x == 0) {
86 | mean[plane] = _mean;
87 | var[plane] = _var;
88 | }
89 | }
90 |
91 | std::vector<at::Tensor> mean_var_cuda(at::Tensor x) {
92 | CHECK_INPUT(x);
93 |
94 | // Extract dimensions
95 | int64_t num, chn, sp;
96 | get_dims(x, num, chn, sp);
97 |
98 | // Prepare output tensors
99 | auto mean = at::empty(x.type(), {chn});
100 | auto var = at::empty(x.type(), {chn});
101 |
102 | // Run kernel
103 | dim3 blocks(chn);
104 | dim3 threads(getNumThreads(sp));
105 | AT_DISPATCH_FLOATING_TYPES(x.type(), "mean_var_cuda", ([&] {
106 |     mean_var_kernel<scalar_t><<<blocks, threads>>>(
107 |         x.data<scalar_t>(),
108 |         mean.data<scalar_t>(),
109 |         var.data<scalar_t>(),
110 | num, chn, sp);
111 | }));
112 |
113 | return {mean, var};
114 | }
115 |
116 | /**********
117 | * forward
118 | **********/
119 |
120 | template<typename T>
121 | __global__ void forward_kernel(T *x, const T *mean, const T *var, const T *weight, const T *bias,
122 | bool affine, float eps, int num, int chn, int sp) {
123 | int plane = blockIdx.x;
124 |
125 | T _mean = mean[plane];
126 | T _var = var[plane];
127 | T _weight = affine ? abs(weight[plane]) + eps : T(1);
128 | T _bias = affine ? bias[plane] : T(0);
129 |
130 | T mul = rsqrt(_var + eps) * _weight;
131 |
132 | for (int batch = 0; batch < num; ++batch) {
133 | for (int n = threadIdx.x; n < sp; n += blockDim.x) {
134 | T _x = x[(batch * chn + plane) * sp + n];
135 | T _y = (_x - _mean) * mul + _bias;
136 |
137 | x[(batch * chn + plane) * sp + n] = _y;
138 | }
139 | }
140 | }
141 |
142 | at::Tensor forward_cuda(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
143 | bool affine, float eps) {
144 | CHECK_INPUT(x);
145 | CHECK_INPUT(mean);
146 | CHECK_INPUT(var);
147 | CHECK_INPUT(weight);
148 | CHECK_INPUT(bias);
149 |
150 | // Extract dimensions
151 | int64_t num, chn, sp;
152 | get_dims(x, num, chn, sp);
153 |
154 | // Run kernel
155 | dim3 blocks(chn);
156 | dim3 threads(getNumThreads(sp));
157 | AT_DISPATCH_FLOATING_TYPES(x.type(), "forward_cuda", ([&] {
158 |     forward_kernel<scalar_t><<<blocks, threads>>>(
159 |         x.data<scalar_t>(),
160 |         mean.data<scalar_t>(),
161 |         var.data<scalar_t>(),
162 |         weight.data<scalar_t>(),
163 |         bias.data<scalar_t>(),
164 | affine, eps, num, chn, sp);
165 | }));
166 |
167 | return x;
168 | }
169 |
170 | /***********
171 | * edz_eydz
172 | ***********/
173 |
174 | template<typename T>
175 | __global__ void edz_eydz_kernel(const T *z, const T *dz, const T *weight, const T *bias,
176 | T *edz, T *eydz, bool affine, float eps, int num, int chn, int sp) {
177 | int plane = blockIdx.x;
178 |
179 | T _weight = affine ? abs(weight[plane]) + eps : 1.f;
180 | T _bias = affine ? bias[plane] : 0.f;
181 |
182 |   Pair<T> res = reduce<Pair<T>, GradOp<T>>(GradOp<T>(_weight, _bias, z, dz, chn, sp), plane, num, chn, sp);
183 | __syncthreads();
184 |
185 | if (threadIdx.x == 0) {
186 | edz[plane] = res.v1;
187 | eydz[plane] = res.v2;
188 | }
189 | }
190 |
191 | std::vector<at::Tensor> edz_eydz_cuda(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
192 | bool affine, float eps) {
193 | CHECK_INPUT(z);
194 | CHECK_INPUT(dz);
195 | CHECK_INPUT(weight);
196 | CHECK_INPUT(bias);
197 |
198 | // Extract dimensions
199 | int64_t num, chn, sp;
200 | get_dims(z, num, chn, sp);
201 |
202 | auto edz = at::empty(z.type(), {chn});
203 | auto eydz = at::empty(z.type(), {chn});
204 |
205 | // Run kernel
206 | dim3 blocks(chn);
207 | dim3 threads(getNumThreads(sp));
208 | AT_DISPATCH_FLOATING_TYPES(z.type(), "edz_eydz_cuda", ([&] {
209 |     edz_eydz_kernel<scalar_t><<<blocks, threads>>>(
210 |         z.data<scalar_t>(),
211 |         dz.data<scalar_t>(),
212 |         weight.data<scalar_t>(),
213 |         bias.data<scalar_t>(),
214 |         edz.data<scalar_t>(),
215 |         eydz.data<scalar_t>(),
216 | affine, eps, num, chn, sp);
217 | }));
218 |
219 | return {edz, eydz};
220 | }
221 |
222 | /***********
223 | * backward
224 | ***********/
225 |
226 | template<typename T>
227 | __global__ void backward_kernel(const T *z, const T *dz, const T *var, const T *weight, const T *bias, const T *edz,
228 | const T *eydz, T *dx, T *dweight, T *dbias,
229 | bool affine, float eps, int num, int chn, int sp) {
230 | int plane = blockIdx.x;
231 |
232 | T _weight = affine ? abs(weight[plane]) + eps : 1.f;
233 | T _bias = affine ? bias[plane] : 0.f;
234 | T _var = var[plane];
235 | T _edz = edz[plane];
236 | T _eydz = eydz[plane];
237 |
238 | T _mul = _weight * rsqrt(_var + eps);
239 | T count = T(num * sp);
240 |
241 | for (int batch = 0; batch < num; ++batch) {
242 | for (int n = threadIdx.x; n < sp; n += blockDim.x) {
243 | T _dz = dz[(batch * chn + plane) * sp + n];
244 | T _y = (z[(batch * chn + plane) * sp + n] - _bias) / _weight;
245 |
246 | dx[(batch * chn + plane) * sp + n] = (_dz - _edz / count - _y * _eydz / count) * _mul;
247 | }
248 | }
249 |
250 | if (threadIdx.x == 0) {
251 | if (affine) {
252 | dweight[plane] = weight[plane] > 0 ? _eydz : -_eydz;
253 | dbias[plane] = _edz;
254 | }
255 | }
256 | }
257 |
258 | std::vector<at::Tensor> backward_cuda(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
259 | at::Tensor edz, at::Tensor eydz, bool affine, float eps) {
260 | CHECK_INPUT(z);
261 | CHECK_INPUT(dz);
262 | CHECK_INPUT(var);
263 | CHECK_INPUT(weight);
264 | CHECK_INPUT(bias);
265 | CHECK_INPUT(edz);
266 | CHECK_INPUT(eydz);
267 |
268 | // Extract dimensions
269 | int64_t num, chn, sp;
270 | get_dims(z, num, chn, sp);
271 |
272 | auto dx = at::zeros_like(z);
273 | auto dweight = at::zeros_like(weight);
274 | auto dbias = at::zeros_like(bias);
275 |
276 | // Run kernel
277 | dim3 blocks(chn);
278 | dim3 threads(getNumThreads(sp));
279 | AT_DISPATCH_FLOATING_TYPES(z.type(), "backward_cuda", ([&] {
280 |     backward_kernel<scalar_t><<<blocks, threads>>>(
281 |         z.data<scalar_t>(),
282 |         dz.data<scalar_t>(),
283 |         var.data<scalar_t>(),
284 |         weight.data<scalar_t>(),
285 |         bias.data<scalar_t>(),
286 |         edz.data<scalar_t>(),
287 |         eydz.data<scalar_t>(),
288 |         dx.data<scalar_t>(),
289 |         dweight.data<scalar_t>(),
290 |         dbias.data<scalar_t>(),
291 | affine, eps, num, chn, sp);
292 | }));
293 |
294 | return {dx, dweight, dbias};
295 | }
296 |
297 | /**************
298 | * activations
299 | **************/
300 |
301 | template<typename T>
302 | inline void leaky_relu_backward_impl(T *z, T *dz, float slope, int64_t count) {
303 | // Create thrust pointers
304 |   thrust::device_ptr<T> th_z = thrust::device_pointer_cast(z);
305 |   thrust::device_ptr<T> th_dz = thrust::device_pointer_cast(dz);
306 |
307 | thrust::transform_if(th_dz, th_dz + count, th_z, th_dz,
308 | [slope] __device__ (const T& dz) { return dz * slope; },
309 | [] __device__ (const T& z) { return z < 0; });
310 | thrust::transform_if(th_z, th_z + count, th_z,
311 | [slope] __device__ (const T& z) { return z / slope; },
312 | [] __device__ (const T& z) { return z < 0; });
313 | }
314 |
315 | void leaky_relu_backward_cuda(at::Tensor z, at::Tensor dz, float slope) {
316 | CHECK_INPUT(z);
317 | CHECK_INPUT(dz);
318 |
319 | int64_t count = z.numel();
320 |
321 | AT_DISPATCH_FLOATING_TYPES(z.type(), "leaky_relu_backward_cuda", ([&] {
322 |     leaky_relu_backward_impl<scalar_t>(z.data<scalar_t>(), dz.data<scalar_t>(), slope, count);
323 | }));
324 | }
325 |
326 | template<typename T>
327 | inline void elu_backward_impl(T *z, T *dz, int64_t count) {
328 | // Create thrust pointers
329 |   thrust::device_ptr<T> th_z = thrust::device_pointer_cast(z);
330 |   thrust::device_ptr<T> th_dz = thrust::device_pointer_cast(dz);
331 |
332 | thrust::transform_if(th_dz, th_dz + count, th_z, th_z, th_dz,
333 | [] __device__ (const T& dz, const T& z) { return dz * (z + 1.); },
334 | [] __device__ (const T& z) { return z < 0; });
335 | thrust::transform_if(th_z, th_z + count, th_z,
336 | [] __device__ (const T& z) { return log1p(z); },
337 | [] __device__ (const T& z) { return z < 0; });
338 | }
339 |
340 | void elu_backward_cuda(at::Tensor z, at::Tensor dz) {
341 | CHECK_INPUT(z);
342 | CHECK_INPUT(dz);
343 |
344 | int64_t count = z.numel();
345 |
346 | AT_DISPATCH_FLOATING_TYPES(z.type(), "leaky_relu_backward_cuda", ([&] {
347 | elu_backward_impl(z.data(), dz.data(), count);
348 | }));
349 | }
350 |
--------------------------------------------------------------------------------
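
The activation backwards follow the same inversion pattern: where `z < 0`, the incoming gradient is rescaled and the pre-activation value is restored in the same buffer, so only the post-activation output has to be stored. A plain-PyTorch mirror of `leaky_relu_backward` (hypothetical helper, for illustration only):

```python
# Mirrors leaky_relu_backward_cpu/_cuda: where z < 0, scale dz by the slope
# and divide z by the slope to recover the pre-activation input in place.
import torch
import torch.nn.functional as F

def leaky_relu_backward(z, dz, slope):
    neg = z < 0
    dz[neg] *= slope   # d(leaky_relu)/dx = slope for x < 0
    z[neg] /= slope    # invert the forward: x = z / slope
    return z, dz

x = torch.randn(16)
z = F.leaky_relu(x, negative_slope=0.01)
z_rec, _ = leaky_relu_backward(z.clone(), torch.ones(16), 0.01)
assert torch.allclose(z_rec, x, atol=1e-6)
```
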
/lib/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuzhuang13/anytime/d2b58de8b7f99c14550e2ae7a715ad68736f846d/lib/utils/__init__.py
--------------------------------------------------------------------------------
/lib/utils/metric.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | def _fast_hist(label_true, label_pred, n_class):
5 | mask = (label_true >= 0) & (label_true < n_class)
6 | hist = np.bincount(
7 | n_class * label_true[mask].astype(int) + label_pred[mask],
8 | minlength=n_class ** 2,
9 | ).reshape(n_class, n_class)
10 | return hist
11 |
12 |
13 | def scores(label_trues, label_preds, n_class):
14 | hist = np.zeros((n_class, n_class))
15 | for lt, lp in zip(label_trues, label_preds):
16 | hist += _fast_hist(lt.flatten(), lp.flatten(), n_class)
17 | acc = np.diag(hist).sum() / hist.sum()
18 | acc_cls = np.diag(hist) / hist.sum(axis=1)
19 | acc_cls = np.nanmean(acc_cls)
20 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
21 | valid = hist.sum(axis=1) > 0
22 | mean_iu = np.nanmean(iu[valid])
23 | freq = hist.sum(axis=1) / hist.sum()
24 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
25 | cls_iu = dict(zip(range(n_class), iu))
26 |
27 | return {
28 | "Overall Acc": acc,
29 | "Mean Acc": acc_cls,
30 | "FreqW Acc": fwavacc,
31 | "Mean IoU": mean_iu,
32 | "Class IoU": cls_iu,
33 | }
34 |
35 |
36 | def batch_pix_accuracy(output, target):
37 | _, predict = torch.max(output, 1)
38 |
39 | predict = predict.cpu().numpy().astype('int64') + 1
40 | target = target.cpu().numpy().astype('int64') + 1
41 |
42 | pixel_labeled = np.sum(target > 0)
43 | pixel_correct = np.sum((predict == target)*(target > 0))
44 | assert pixel_correct <= pixel_labeled, \
45 | "Correct area should be smaller than Labeled"
46 | return pixel_correct, pixel_labeled
47 |
48 |
49 | def batch_intersection_union(output, target, nclass):
50 | _, predict = torch.max(output, 1)
51 | mini = 1
52 | maxi = nclass
53 | nbins = nclass
54 | predict = predict.cpu().numpy().astype('int64') + 1
55 | target = target.cpu().numpy().astype('int64') + 1
56 |
57 | predict = predict * (target > 0).astype(predict.dtype)
58 | intersection = predict * (predict == target)
59 | area_inter, _ = np.histogram(intersection, bins=nbins, range=(mini, maxi))
60 | area_pred, _ = np.histogram(predict, bins=nbins, range=(mini, maxi))
61 | area_lab, _ = np.histogram(target, bins=nbins, range=(mini, maxi))
62 | area_union = area_pred + area_lab - area_inter
63 | assert (area_inter <= area_union).all(), \
64 | "Intersection area should be smaller than Union area"
65 | return area_inter, area_union
66 |
67 |
68 | def pixel_accuracy(im_pred, im_lab):
69 | im_pred = np.asarray(im_pred)
70 | im_lab = np.asarray(im_lab)
71 | pixel_labeled = np.sum(im_lab > 0)
72 | pixel_correct = np.sum((im_pred == im_lab) * (im_lab > 0))
73 | return pixel_correct, pixel_labeled
74 |
75 |
76 | def intersection_and_union(im_pred, im_lab, num_class):
77 | im_pred = np.asarray(im_pred)
78 | im_lab = np.asarray(im_lab)
79 | im_pred = im_pred * (im_lab > 0)
80 | intersection = im_pred * (im_pred == im_lab)
81 | area_inter, _ = np.histogram(intersection, bins=num_class-1,
82 | range=(1, num_class - 1))
83 | area_pred, _ = np.histogram(im_pred, bins=num_class-1,
84 | range=(1, num_class - 1))
85 | area_lab, _ = np.histogram(im_lab, bins=num_class-1,
86 | range=(1, num_class - 1))
87 | area_union = area_pred + area_lab - area_inter
88 | return area_inter, area_union
89 |
--------------------------------------------------------------------------------
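
A toy worked example of `scores` above (assuming the functions in this file are in scope): `_fast_hist` accumulates a confusion matrix, from which the per-class IoU and the reported means follow.

```python
# Toy check of scores(): confusion matrix -> per-class IoU -> means.
import numpy as np

label = np.array([[0, 0], [1, 1]])
pred  = np.array([[0, 1], [1, 1]])
out = scores([label], [pred], n_class=2)   # scores() as defined above
print(out["Class IoU"])  # {0: 0.5, 1: 0.666...}
print(out["Mean IoU"])   # 0.5833...
```
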
/lib/utils/modelsummary.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import os
6 | import logging
7 | from collections import namedtuple
8 |
9 | import torch
10 | import torch.nn as nn
11 |
12 | def get_model_summary(model, *input_tensors, item_length=26, verbose=False):
13 | summary = []
14 |
15 | ModuleDetails = namedtuple(
16 | "Layer", ["name", "input_size", "output_size", "num_parameters", "multiply_adds", "num_param_counts"])
17 | hooks = []
18 | layer_instances = {}
19 |
20 | def add_hooks(module):
21 |
22 | def hook(module, input, output):
23 | class_name = str(module.__class__.__name__)
24 |
25 | instance_index = 1
26 | if class_name not in layer_instances:
27 | layer_instances[class_name] = instance_index
28 | else:
29 | instance_index = layer_instances[class_name] + 1
30 | layer_instances[class_name] = instance_index
31 |
32 | layer_name = class_name + "_" + str(instance_index)
33 |
34 | params = 0
35 | counts = 0
36 | if class_name.find("Conv") != -1 or class_name.find("BatchNorm") != -1 or \
37 | class_name.find("Linear") != -1 or class_name.find('conv') != -1 or class_name.find('Temp') != -1:
38 | for param_ in module.parameters():
39 | if param_.requires_grad:
40 | params += param_.view(-1).size(0)
41 | counts += 1
42 |
43 | flops = "Not Available"
44 |
45 | if (class_name.find("ConvTranspose2d") != -1) and hasattr(module, "weight"):
46 | flops = (
47 | torch.prod(
48 | torch.LongTensor(list(module.weight.data.size()))) *
49 | torch.prod(
50 | torch.LongTensor(list(input[0].size())[2:]))).item()
51 |
52 |
53 | elif (class_name.find("Conv") != -1) and hasattr(module, "weight"):
54 | flops = (
55 | torch.prod(
56 | torch.LongTensor(list(module.weight.data.size()))) *
57 | torch.prod(
58 | torch.LongTensor(list(output.size())[2:]))).item()
59 |
60 |
61 | elif class_name.find("conv")!= -1:
62 | flops = (
63 | torch.prod(
64 | torch.LongTensor(list(module.weight.data.size()))) *
65 | torch.prod(
66 | torch.LongTensor(list(output.size())[2:]))).item()
67 | p = module.mask.mean().item()
68 | p2 = module.p
69 | flops = int(flops * p)
70 | if module.interpolate == 'rbf' or module.interpolate == 'pooling' or module.interpolate == 'conv':
71 | flops += int((1 - p) * p * output.size(1) * (module.r * module.r) * output.size(2) * output.size(3))
72 | elif isinstance(module, nn.Linear):
73 | flops = (torch.prod(torch.LongTensor(list(output.size()))) \
74 | * input[0].size(1)).item()
75 |
76 | if isinstance(input[0], list):
77 | input = input[0]
78 | if isinstance(output, list):
79 | output = output[0]
80 |
81 | summary.append(
82 | ModuleDetails(
83 | name=layer_name,
84 | input_size=list(input[0].size()),
85 | output_size=list(output.size()),
86 | num_parameters=params,
87 | multiply_adds=flops,
88 | num_param_counts=counts,)
89 | )
90 |
91 | if not isinstance(module, nn.ModuleList) \
92 | and not isinstance(module, nn.Sequential) \
93 | and module != model:
94 | hooks.append(module.register_forward_hook(hook))
95 |
96 | model.eval()
97 | model.apply(add_hooks)
98 |
99 | space_len = item_length
100 |
101 | model(*input_tensors)
102 | for hook in hooks:
103 | hook.remove()
104 |
105 | details = ''
106 | if verbose:
107 | details = "Model Summary" + \
108 | os.linesep + \
109 | "Name{}Input Size{}Output Size{}Parameters{}Multiply Adds (Flops){}".format(
110 | ' ' * (space_len - len("Name")),
111 | ' ' * (space_len - len("Input Size")),
112 | ' ' * (space_len - len("Output Size")),
113 | ' ' * (space_len - len("Parameters")),
114 | ' ' * (space_len - len("Multiply Adds (Flops)"))) \
115 | + os.linesep + '-' * space_len * 5 + os.linesep
116 |
117 | params_sum = 0
118 | flops_sum = 0
119 | counts_sum = 0
120 | for layer in summary:
121 | params_sum += layer.num_parameters
122 | if layer.multiply_adds != "Not Available":
123 | flops_sum += layer.multiply_adds
124 | counts_sum += layer.num_param_counts
125 | shown_flops = layer.multiply_adds/(1024**3) if layer.multiply_adds != 'Not Available' else 0
126 | if verbose:
127 | if shown_flops < 1:
128 | details += ''
129 | else:
130 | details += "{}{}{}{}{}{}{}{}{}{}".format(
131 | layer.name,
132 | ' ' * (space_len - len(layer.name)),
133 | layer.input_size,
134 | ' ' * (space_len - len(str(layer.input_size))),
135 | layer.output_size,
136 | ' ' * (space_len - len(str(layer.output_size))),
137 | layer.num_parameters,
138 | ' ' * (space_len - len(str(layer.num_parameters))),
139 | shown_flops,
140 | ' ' * (space_len - len(str(shown_flops)))) \
141 | + os.linesep + '-' * space_len * 5 + os.linesep
142 |
143 | details += os.linesep \
144 | + "Total Parameters: {:,}".format(params_sum) \
145 | + os.linesep + '-' * space_len * 5 + os.linesep
146 | details += "Total Multiply Adds (For Convolution and Linear Layers only): {:,} GFLOPs".format(flops_sum/(1024**3)) \
147 | + os.linesep + '-' * space_len * 5 + os.linesep
148 | details += "Number of Layers" + os.linesep
149 | for layer in layer_instances:
150 | details += "{} : {} layers ".format(layer, layer_instances[layer])
151 |
152 | return details, {'params': params_sum,
153 | 'flops': flops_sum,
154 | 'counts': counts_sum}
155 |
--------------------------------------------------------------------------------
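
`get_model_summary` drives the FLOPs/params accounting used by `tools/train_ee.py` and `tools/test_ee.py`: it registers forward hooks on Conv/BatchNorm/Linear modules, runs one dummy forward pass, and totals multiply-adds. A hypothetical standalone usage on a small net:

```python
# Hypothetical usage: summarize a small conv net with one dummy forward pass.
import torch
import torch.nn as nn

net = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 16, 3, padding=1),
)
details, totals = get_model_summary(net, torch.rand(1, 3, 32, 32), verbose=True)
print(totals)  # {'params': ..., 'flops': ..., 'counts': ...}
```
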
/lib/utils/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import os
6 | import logging
7 | import time
8 | from pathlib import Path
9 |
10 | import numpy as np
11 |
12 | import torch
13 | import torch.nn as nn
14 | import torch.nn.functional as F
15 |
16 | import pdb
17 |
18 | def json_save(filename, json_obj):
19 | import json
20 | with open(filename, 'w') as f:
21 | json.dump(json_obj, f, indent=4)
22 |
23 | def json_read(filename):
24 | import json
25 | with open(filename, 'r') as f:
26 | data = json.load(f)
27 |
28 | return data
29 |
30 |
31 | class FullModel(nn.Module):
32 | def __init__(self, model, loss):
33 | super(FullModel, self).__init__()
34 | self.model = model
35 | self.loss = loss
36 |
37 | def forward(self, inputs, labels):
38 | outputs = self.model(inputs)
39 | loss = self.loss(outputs, labels)
40 | return torch.unsqueeze(loss,0), outputs
41 |
42 |
43 | class FullEEModel(nn.Module):
44 | def __init__(self, model, loss, config=None):
45 | super(FullEEModel, self).__init__()
46 | self.model = model
47 | self.loss = loss
48 | self.cfg = config
49 |
50 | def forward(self, inputs, labels):
51 | outputs = self.model(inputs)
52 | losses = []
53 | for i, output in enumerate(outputs):
54 | losses.append(self.loss(outputs[i], labels))
55 |
56 | return losses, outputs
57 |
58 | def get_world_size():
59 | if not torch.distributed.is_initialized():
60 | return 1
61 | return torch.distributed.get_world_size()
62 |
63 | def get_rank():
64 | if not torch.distributed.is_initialized():
65 | return 0
66 | return torch.distributed.get_rank()
67 |
68 | class AverageMeter(object):
69 | def __init__(self):
70 | self.initialized = False
71 | self.val = None
72 | self.avg = None
73 | self.sum = None
74 | self.count = None
75 |
76 | def initialize(self, val, weight):
77 | self.val = val
78 | self.avg = val
79 | self.sum = val * weight
80 | self.count = weight
81 | self.initialized = True
82 |
83 | def update(self, val, weight=1):
84 | if not self.initialized:
85 | self.initialize(val, weight)
86 | else:
87 | self.add(val, weight)
88 |
89 | def add(self, val, weight):
90 | self.val = val
91 | self.sum += val * weight
92 | self.count += weight
93 | self.avg = self.sum / self.count
94 |
95 | def value(self):
96 | return self.val
97 |
98 | def average(self):
99 | return self.avg
100 |
101 | def create_logger(cfg, cfg_name, phase='train'):
102 | root_output_dir = Path(cfg.OUTPUT_DIR)
103 | os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
104 |
105 | dataset = cfg.DATASET.DATASET
106 | model = cfg.MODEL.NAME
107 | cfg_name = os.path.basename(cfg_name).split('.')[0]
108 | final_output_dir = root_output_dir
109 | os.makedirs(final_output_dir, exist_ok=True)
110 |
111 | print('=> creating {}'.format(final_output_dir))
112 | final_output_dir.mkdir(parents=True, exist_ok=True)
113 |
114 | time_str = time.strftime('%Y-%m-%d-%H-%M')
115 | log_file = '{}_{}_{}.log'.format(cfg_name, time_str, phase)
116 | final_log_file = final_output_dir / log_file
117 | head = '%(asctime)-15s %(message)s'
118 | logging.basicConfig(filename=str(final_log_file),
119 | format=head)
120 | logger = logging.getLogger()
121 | logger.setLevel(logging.INFO)
122 | console = logging.StreamHandler()
123 | logging.getLogger('').addHandler(console)
124 | tensorboard_log_dir = final_output_dir
125 | print('=> creating {}'.format(tensorboard_log_dir))
126 | tensorboard_log_dir.mkdir(parents=True, exist_ok=True)
127 |
128 | return logger, str(final_output_dir), str(tensorboard_log_dir)
129 |
130 |
131 | def get_confusion_matrix_gpu(label, pred, size, num_class, ignore=-1, device=0):
132 | output = pred.transpose(1,3).transpose(1,2)
133 | seg_pred = torch.max(output, dim=3)[1]
134 | seg_gt = label
135 |
136 | ignore_index = seg_gt != ignore
137 |
138 | seg_gt = seg_gt[ignore_index]
139 |
140 | seg_pred = seg_pred[ignore_index]
141 | if seg_gt.get_device() == -1:
142 | seg_gt = seg_gt.to(0)
143 |
144 | index = ((seg_gt * num_class).long()+ seg_pred)
145 | label_count = torch.bincount(index)
146 | confusion_matrix = np.zeros((num_class, num_class))
147 |
148 | for i_label in range(num_class):
149 | for i_pred in range(num_class):
150 | cur_index = i_label * num_class + i_pred
151 | if cur_index < len(label_count):
152 | confusion_matrix[i_label,
153 | i_pred] = label_count[cur_index]
154 | return confusion_matrix
155 |
156 | def adjust_learning_rate(optimizer, base_lr, max_iters,
157 | cur_iters, power=0.9):
158 | lr = base_lr*((1-float(cur_iters)/max_iters)**(power))
159 | for i, param in enumerate(optimizer.param_groups):
160 | optimizer.param_groups[i]['lr'] = lr
161 | return lr
--------------------------------------------------------------------------------
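
`adjust_learning_rate` implements the poly schedule `lr = base_lr * (1 - cur_iters / max_iters)^0.9`. A few sample points of the decay:

```python
# The poly decay computed by adjust_learning_rate above.
base_lr, max_iters, power = 0.01, 1000, 0.9
for t in (0, 250, 500, 750, 999):
    print(t, round(base_lr * (1 - t / max_iters) ** power, 5))
# -> 0.01, 0.00772, 0.00536, 0.00287, 2e-05
```
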
/requirements.txt:
--------------------------------------------------------------------------------
1 | cityscapesscripts
2 | numpy
3 | Pillow
4 | tensorboardX
5 | tqdm
6 | yacs
--------------------------------------------------------------------------------
/tools/_init_paths.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import os.path as osp
6 | import sys
7 |
8 |
9 | def add_path(path):
10 | if path not in sys.path:
11 | sys.path.insert(0, path)
12 |
13 | this_dir = osp.dirname(__file__)
14 |
15 | lib_path = osp.join(this_dir, '..', 'lib')
16 | add_path(lib_path)
17 |
--------------------------------------------------------------------------------
/tools/test_ee.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import pprint
4 | import shutil
5 | import sys
6 |
7 | import logging
8 | import time
9 | import timeit
10 | from pathlib import Path
11 |
12 | import numpy as np
13 |
14 | import torch
15 | import torch.nn as nn
16 | import torch.backends.cudnn as cudnn
17 |
18 | import _init_paths
19 | import models
20 | import datasets
21 | from config import config
22 | from config import update_config
23 | from core.function import testval_ee, testval_ee_profiling, testval_ee_profiling_actual
24 | from utils.modelsummary import get_model_summary
25 | from utils.utils import create_logger, FullModel, FullEEModel, json_save
26 |
27 | import pdb
28 |
29 | def parse_args():
30 | parser = argparse.ArgumentParser(description='Train segmentation network')
31 |
32 | parser.add_argument('--cfg',
33 | help='experiment configure file name',
34 | required=True,
35 | type=str)
36 | parser.add_argument('opts',
37 | help="Modify config options using the command-line",
38 | default=None,
39 | nargs=argparse.REMAINDER)
40 |
41 | args = parser.parse_args()
42 |
43 | update_config(config, args)
44 |
45 | return args
46 |
47 | def main():
48 | args = parse_args()
49 |
50 | config.defrost()
51 | config.OUTPUT_DIR = args.cfg[:-len('config.yaml')]
52 | try:
53 | if config.TEST.SUB_DIR:
54 | config.OUTPUT_DIR = os.path.join(config.OUTPUT_DIR, config.TEST.SUB_DIR)
55 |     except AttributeError:
56 | pass
57 | config.freeze()
58 |
59 | logger, final_output_dir, _ = create_logger(
60 | config, args.cfg, 'test')
61 |
62 | logger.info(pprint.pformat(args))
63 | logger.info(pprint.pformat(config))
64 |
65 | cudnn.benchmark = config.CUDNN.BENCHMARK
66 | cudnn.deterministic = config.CUDNN.DETERMINISTIC
67 | cudnn.enabled = config.CUDNN.ENABLED
68 |
69 | model = eval('models.'+config.MODEL.NAME +
70 | '.get_seg_model')(config)
71 |
72 | device = 0
73 | model.eval()
74 |
75 | dump_input = torch.rand(
76 | (1, 3, config.TEST.IMAGE_SIZE[1], config.TEST.IMAGE_SIZE[0])
77 | )
78 |
79 | if config.PYRAMID_TEST.USE:
80 | dump_input = torch.rand(
81 | (1, 3, config.PYRAMID_TEST.SIZE, config.PYRAMID_TEST.SIZE // 2)
82 | )
83 | dump_output = model.to(device)(dump_input.to(device))
84 | del dump_output
85 | dump_output = model.to(device)(dump_input.to(device))
86 |
87 | if not (config.MASK.USE and (config.MASK.CRIT == 'conf_thre' or config.MASK.CRIT == 'entropy_thre')):
88 | stats = {}
89 | saved_stats = {}
90 | for i in range(4):
91 | setattr(model, f"stop{i+1}", "anY_RanDOM_ThiNg")
92 | summary, stats[i+1] = get_model_summary(model.to(device), dump_input.to(device), verbose=True)
93 | delattr(model, f"stop{i+1}")
94 |
95 | logger.info(f'\n\n>>>>>>>>>>>>>>>>>>>>>>> EXIT {i+1} >>>>>>>>>>>>>>>>>>>>>>>>>> ')
96 | logger.info(summary)
97 |
98 | saved_stats['params'] = [stats[i+1]['params'] for i in range(4)]
99 | saved_stats['flops'] = [stats[i+1]['flops'] for i in range(4)]
100 | saved_stats['counts'] = [stats[i+1]['counts'] for i in range(4)]
101 | saved_stats['Gflops'] = [f/(1024**3) for f in saved_stats['flops']]
102 | saved_stats['Gflops_mean'] = np.mean(saved_stats['Gflops'])
103 | saved_stats['Mparams'] = [f/(10**6) for f in saved_stats['params']]
104 | json_save(os.path.join(final_output_dir, 'test_stats.json'), saved_stats)
105 |
106 | if config.TEST.MODEL_FILE:
107 | model_state_file = config.TEST.MODEL_FILE
108 | else:
109 | model_state_file = os.path.join(final_output_dir,
110 | 'final_state.pth')
111 |
112 | try:
113 | if config.TEST.SUB_DIR:
114 | model_state_file = args.cfg[:-len('config.yaml')] + 'final_state.pth'
115 |     except AttributeError:
116 | pass
117 |
118 | logger.info('=> loading model from {}'.format(model_state_file))
119 |
120 | pretrained_dict = torch.load(model_state_file)
121 | model_dict = model.state_dict()
122 | pretrained_dict = {k[6:]: v for k, v in pretrained_dict.items()
123 | if k[6:] in model_dict.keys()}
124 | model_dict.update(pretrained_dict)
125 | model.load_state_dict(model_dict)
126 |
127 | gpus = [0]
128 |
129 | model = nn.DataParallel(model, device_ids=gpus).cuda()
130 | test_size = (config.TEST.IMAGE_SIZE[1], config.TEST.IMAGE_SIZE[0])
131 | test_dataset = eval('datasets.'+config.DATASET.DATASET)(
132 | root=config.DATASET.ROOT,
133 | list_path=config.DATASET.TEST_SET,
134 | num_samples=None,
135 | num_classes=config.DATASET.NUM_CLASSES,
136 | multi_scale=False,
137 | flip=False,
138 | ignore_label=config.TRAIN.IGNORE_LABEL,
139 | base_size=config.TEST.BASE_SIZE,
140 | crop_size=test_size,
141 | downsample_rate=1)
142 |
143 | testloader = torch.utils.data.DataLoader(
144 | test_dataset,
145 | batch_size=1,
146 | shuffle=False,
147 | num_workers=config.WORKERS,
148 | pin_memory=True)
149 |
150 | start = timeit.default_timer()
151 |
152 | if 'val' in config.DATASET.TEST_SET:
153 | results = testval_ee(config,
154 | test_dataset,
155 | testloader,
156 | model, sv_dir=final_output_dir, sv_pred=True)
157 |
158 | if config.MASK.USE and config.MASK.CRIT == 'conf_thre':
159 | results_profiling = testval_ee_profiling(config,
160 | test_dataset,
161 | testloader,
162 | model, sv_dir=final_output_dir, sv_pred=True)
163 | json_save(os.path.join(final_output_dir, 'test_stats.json'), results_profiling)
164 |
165 | mean_IoUs = []
166 | for i, result in enumerate(results):
167 | mean_IoU, IoU_array, pixel_acc, mean_acc = result
168 |
169 | msg = 'Exit: {}, MeanIU: {: 4.4f}, Pixel_Acc: {: 4.4f}, \
170 | Mean_Acc: {: 4.4f}, Class IoU: '.format(i+1, mean_IoU,
171 | pixel_acc, mean_acc)
172 | logging.info(msg)
173 | logging.info(IoU_array)
174 |
175 | mean_IoUs.append(mean_IoU)
176 |
177 |
178 | mean_IoUs.append(np.mean(mean_IoUs))
179 | print_result = '\t'.join(['{:.2f}'.format(m*100) for m in mean_IoUs])
180 | result_file_name = f'{final_output_dir}/result.txt'
181 |
182 | with open(result_file_name, 'w') as f:
183 | f.write(print_result)
184 |
185 | end = timeit.default_timer()
186 |     logger.info('Mins: %d' % int((end-start)/60))
187 | logger.info('Done')
188 | logging.info(print_result)
189 |
190 | if __name__ == '__main__':
191 | main()
192 |
--------------------------------------------------------------------------------
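
One non-obvious detail in the checkpoint loading above: `final_state.pth` is saved from the `FullEEModel` wrapper, whose `model` attribute prefixes every parameter key with `model.` (six characters), hence the `k[6:]` slice. A small illustration:

```python
# Why test_ee.py slices k[6:]: checkpoint keys carry FullEEModel's
# "model." prefix (6 characters), which must be stripped before loading.
state = {"model.exit1.weight": 0, "model.last_layer.bias": 1}
stripped = {k[6:]: v for k, v in state.items()}
print(stripped)  # {'exit1.weight': 0, 'last_layer.bias': 1}
```
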
/tools/train_ee.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import pprint
4 | import shutil
5 | import sys
6 |
7 | import logging
8 | import time
9 | import timeit
10 | from pathlib import Path
11 |
12 | import numpy as np
13 |
14 | import torch
15 | import torch.nn as nn
16 | import torch.backends.cudnn as cudnn
17 | import torch.optim
18 | from torch.utils.data.distributed import DistributedSampler
19 | from tensorboardX import SummaryWriter
20 |
21 | import _init_paths
22 | import models
23 | import datasets
24 | from config import config
25 | from config import update_config
26 | from core.criterion import CrossEntropy
27 | from core.function import train_ee, validate_ee
28 | from utils.modelsummary import get_model_summary
29 | from utils.utils import create_logger, FullModel, FullEEModel, get_rank, json_save
30 |
31 | import pdb, time
32 | import subprocess
33 |
34 | def parse_args():
35 | parser = argparse.ArgumentParser(description='Train segmentation network')
36 |
37 | parser.add_argument('--cfg',
38 | help='experiment configure file name',
39 | required=True,
40 | type=str)
41 | parser.add_argument("--local_rank", type=int, default=0)
42 | parser.add_argument('opts',
43 | help="Modify config options using the command-line",
44 | default=None,
45 | nargs=argparse.REMAINDER)
46 |
47 | args = parser.parse_args()
48 | update_config(config, args)
49 |
50 | return args
51 |
52 | def main():
53 | args = parse_args()
54 |
55 | logger, final_output_dir, tb_log_dir = create_logger(
56 | config, args.cfg, 'train')
57 |
58 | if args.local_rank == 0:
59 | logger.info(config)
60 |
61 | writer_dict = {
62 | 'writer': SummaryWriter(tb_log_dir),
63 | 'train_global_steps': 0,
64 | 'valid_global_steps': 0,
65 | }
66 |
67 | cudnn.benchmark = config.CUDNN.BENCHMARK
68 | cudnn.deterministic = config.CUDNN.DETERMINISTIC
69 | cudnn.enabled = config.CUDNN.ENABLED
70 |
71 | gpus = list(config.GPUS)
72 | distributed = len(gpus) > 1
73 | device = torch.device('cuda:{}'.format(args.local_rank))
74 |
75 | model = eval('models.'+config.MODEL.NAME +
76 | '.get_seg_model')(config)
77 |
78 |
79 |
80 | if args.local_rank == 0:
81 | with open(f"{final_output_dir}/config.yaml", "w") as f:
82 | f.write(config.dump())
83 |
84 | this_dir = os.path.dirname(__file__)
85 | models_dst_dir = os.path.join(final_output_dir, 'code')
86 | if os.path.exists(models_dst_dir):
87 | shutil.rmtree(models_dst_dir)
88 | shutil.copytree(os.path.join(this_dir, '../lib'), os.path.join(models_dst_dir, 'lib'))
89 | shutil.copytree(os.path.join(this_dir, '../tools'), os.path.join(models_dst_dir, 'tools'))
90 | shutil.copytree(os.path.join(this_dir, '../scripts'), os.path.join(models_dst_dir, 'scripts'))
91 | shutil.copytree(os.path.join(this_dir, '../experiments'), os.path.join(models_dst_dir, 'experiments'))
92 |
93 | if True:
94 | model.eval()
95 | dump_input = torch.rand(
96 | (1, 3, config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0])
97 | )
98 | dump_output = model.to(device)(dump_input.to(device))
99 |
100 | dump_output = model.to(device)(dump_input.to(device))
101 |
102 | stats = {}
103 | saved_stats = {}
104 | for i in range(4):
105 | setattr(model, f"stop{i+1}", "anY_RanDOM_ThiNg")
106 | summary, stats[i+1] = get_model_summary(model.to(device), dump_input.to(device), verbose=False)
107 | delattr(model, f"stop{i+1}")
108 |
109 | if args.local_rank == 0:
110 | logger.info(f'\n\n>>>>>>>>>>>>>>>>>>>>>>> EXIT {i+1} >>>>>>>>>>>>>>>>>>>>>>>>>> ')
111 | logger.info(summary)
112 |
113 | saved_stats['params'] = [stats[i+1]['params'] for i in range(4)]
114 | saved_stats['flops'] = [stats[i+1]['flops'] for i in range(4)]
115 | saved_stats['counts'] = [stats[i+1]['counts'] for i in range(4)]
116 | saved_stats['Gflops'] = [f/(1024**3) for f in saved_stats['flops']]
117 | saved_stats['Mparams'] = [f/(10**6) for f in saved_stats['params']]
118 |
119 | json_save(os.path.join(final_output_dir, 'stats.json'), saved_stats)
120 |
121 |
122 | if distributed:
123 | torch.cuda.set_device(args.local_rank)
124 | torch.distributed.init_process_group(
125 | backend="nccl", init_method="env://",
126 | )
127 |
128 | crop_size = (config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0])
129 | train_dataset = eval('datasets.'+config.DATASET.DATASET)(
130 | root=config.DATASET.ROOT,
131 | list_path=config.DATASET.TRAIN_SET,
132 | num_samples=None,
133 | num_classes=config.DATASET.NUM_CLASSES,
134 | multi_scale=config.TRAIN.MULTI_SCALE,
135 | flip=config.TRAIN.FLIP,
136 | ignore_label=config.TRAIN.IGNORE_LABEL,
137 | base_size=config.TRAIN.BASE_SIZE,
138 | crop_size=crop_size,
139 | downsample_rate=config.TRAIN.DOWNSAMPLERATE,
140 | scale_factor=config.TRAIN.SCALE_FACTOR)
141 |
142 | if distributed:
143 | train_sampler = DistributedSampler(train_dataset)
144 | else:
145 | train_sampler = None
146 |
147 | trainloader = torch.utils.data.DataLoader(
148 | train_dataset,
149 | batch_size=config.TRAIN.BATCH_SIZE_PER_GPU,
150 | shuffle=config.TRAIN.SHUFFLE and train_sampler is None,
151 | num_workers=config.WORKERS,
152 | pin_memory=True,
153 | drop_last=True,
154 | sampler=train_sampler)
155 |
156 | if config.DATASET.EXTRA_TRAIN_SET:
157 | extra_train_dataset = eval('datasets.'+config.DATASET.DATASET)(
158 | root=config.DATASET.ROOT,
159 | list_path=config.DATASET.EXTRA_TRAIN_SET,
160 | num_samples=None,
161 | num_classes=config.DATASET.NUM_CLASSES,
162 | multi_scale=config.TRAIN.MULTI_SCALE,
163 | flip=config.TRAIN.FLIP,
164 | ignore_label=config.TRAIN.IGNORE_LABEL,
165 | base_size=config.TRAIN.BASE_SIZE,
166 | crop_size=crop_size,
167 | downsample_rate=config.TRAIN.DOWNSAMPLERATE,
168 | scale_factor=config.TRAIN.SCALE_FACTOR)
169 |
170 | if distributed:
171 | extra_train_sampler = DistributedSampler(extra_train_dataset)
172 | else:
173 | extra_train_sampler = None
174 |
175 | extra_trainloader = torch.utils.data.DataLoader(
176 | extra_train_dataset,
177 | batch_size=config.TRAIN.BATCH_SIZE_PER_GPU,
178 | shuffle=config.TRAIN.SHUFFLE and extra_train_sampler is None,
179 | num_workers=config.WORKERS,
180 | pin_memory=True,
181 | drop_last=True,
182 | sampler=extra_train_sampler)
183 |
184 | test_size = (config.TEST.IMAGE_SIZE[1], config.TEST.IMAGE_SIZE[0])
185 | test_dataset = eval('datasets.'+config.DATASET.DATASET)(
186 | root=config.DATASET.ROOT,
187 | list_path=config.DATASET.TEST_SET,
188 | num_samples=config.TEST.NUM_SAMPLES,
189 | num_classes=config.DATASET.NUM_CLASSES,
190 | multi_scale=False,
191 | flip=False,
192 | ignore_label=config.TRAIN.IGNORE_LABEL,
193 | base_size=config.TEST.BASE_SIZE,
194 | crop_size=test_size,
195 | center_crop_test=config.TEST.CENTER_CROP_TEST,
196 | downsample_rate=1)
197 |
198 | if distributed:
199 | test_sampler = DistributedSampler(test_dataset)
200 | else:
201 | test_sampler = None
202 |
203 | testloader = torch.utils.data.DataLoader(
204 | test_dataset,
205 | batch_size=config.TEST.BATCH_SIZE_PER_GPU,
206 | shuffle=False,
207 | num_workers=config.WORKERS,
208 | pin_memory=True,
209 | sampler=test_sampler)
210 |
211 | criterion = CrossEntropy(ignore_label=config.TRAIN.IGNORE_LABEL,
212 | weight=train_dataset.class_weights)
213 |
214 | model = FullEEModel(model, criterion, config=config)
215 |
216 | model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
217 | model = model.to(device)
218 | model = nn.parallel.DistributedDataParallel(
219 | model, device_ids=[args.local_rank], output_device=args.local_rank)
220 |
221 | if config.TRAIN.OPTIMIZER == 'sgd':
222 | if config.TRAIN.ALLE_ONLY:
223 | param = [
224 | {'params': model.module.model.exit1.parameters(), 'lr': config.TRAIN.EXTRA_LR},
225 | {'params': model.module.model.exit2.parameters(), 'lr': config.TRAIN.EXTRA_LR},
226 | {'params': model.module.model.exit3.parameters(), 'lr': config.TRAIN.EXTRA_LR},
227 | {'params': model.module.model.last_layer.parameters(), 'lr': config.TRAIN.EXTRA_LR},
228 | ]
229 | elif config.TRAIN.EE_ONLY:
230 | param = [
231 | {'params': model.module.model.exit1.parameters(), 'lr': config.TRAIN.EXTRA_LR},
232 | {'params': model.module.model.exit2.parameters(), 'lr': config.TRAIN.EXTRA_LR},
233 | {'params': model.module.model.exit3.parameters(), 'lr': config.TRAIN.EXTRA_LR}
234 | ]
235 | else:
236 | param = [
237 | {'params':
238 | filter(lambda p: p.requires_grad,
239 | model.parameters()),
240 | 'lr': config.TRAIN.LR}
241 | ]
242 |
243 | optimizer = torch.optim.SGD(param,
244 | lr=config.TRAIN.LR,
245 | momentum=config.TRAIN.MOMENTUM,
246 | weight_decay=config.TRAIN.WD,
247 | nesterov=config.TRAIN.NESTEROV,
248 | )
249 | else:
250 | raise ValueError('Only Support SGD optimizer')
251 |
252 |     epoch_iters = int(train_dataset.__len__() /
253 |                       config.TRAIN.BATCH_SIZE_PER_GPU / len(gpus))
254 | best_mIoU = 0
255 | last_epoch = 0
256 | if config.TRAIN.RESUME:
257 | if config.DATASET.EXTRA_TRAIN_SET:
258 | model_state_file = os.path.join(config.RESUME_DIR, 'checkpoint.pth.tar')
259 | assert os.path.isfile(model_state_file)
260 | load_optimizer_dict = False
261 | else:
262 | model_state_file = os.path.join(final_output_dir,
263 | 'checkpoint.pth.tar')
264 |
265 |
266 | if os.path.isfile(model_state_file):
267 | checkpoint = torch.load(model_state_file,
268 | map_location=lambda storage, loc: storage)
269 | best_mIoU = checkpoint['best_mIoU']
270 | last_epoch = checkpoint['epoch']
271 | model.module.load_state_dict(checkpoint['state_dict'])
272 | if not config.DATASET.EXTRA_TRAIN_SET:
273 | optimizer.load_state_dict(checkpoint['optimizer'])
274 | logger.info("=> loaded checkpoint (epoch {})"
275 | .format(checkpoint['epoch']))
276 |
277 |
278 | start = timeit.default_timer()
279 | end_epoch = config.TRAIN.END_EPOCH + config.TRAIN.EXTRA_EPOCH
280 | num_iters = config.TRAIN.END_EPOCH * epoch_iters
281 | extra_iters = config.TRAIN.EXTRA_EPOCH * epoch_iters
282 |
283 | logger.info('Starting training at rank {}'.format(args.local_rank))
284 | for epoch in range(last_epoch, end_epoch):
285 | if distributed:
286 | train_sampler.set_epoch(epoch)
287 | if epoch >= config.TRAIN.END_EPOCH:
288 | train_ee(config, epoch-config.TRAIN.END_EPOCH,
289 | config.TRAIN.EXTRA_EPOCH, epoch_iters,
290 | config.TRAIN.EXTRA_LR, extra_iters,
291 | extra_trainloader, optimizer, model,
292 | writer_dict, device)
293 | else:
294 | train_ee(config, epoch, config.TRAIN.END_EPOCH,
295 | epoch_iters, config.TRAIN.LR, num_iters,
296 | trainloader, optimizer, model, writer_dict,
297 | device)
298 |
299 | if args.local_rank == 0:
300 | logger.info('=> saving checkpoint to {}'.format(
301 | final_output_dir + 'checkpoint.pth.tar'))
302 | torch.save({
303 | 'epoch': epoch+1,
304 | 'best_mIoU': best_mIoU,
305 | 'state_dict': model.module.state_dict(),
306 | 'optimizer': optimizer.state_dict(),
307 | }, os.path.join(final_output_dir,'checkpoint.pth.tar'))
308 |
309 |
310 | torch.save(model.module.state_dict(), os.path.join(final_output_dir,'checkpoint.pth'))
311 |
312 |
313 | if epoch == end_epoch - 1:
314 | torch.save(model.module.state_dict(),
315 | os.path.join(final_output_dir, 'final_state.pth'))
316 |
317 | writer_dict['writer'].close()
318 | end = timeit.default_timer()
319 | logger.info('Hours: {}'.format((end-start)/3600))
320 | logger.info('Done')
321 |
322 |
323 | pid = os.getpid()
324 | torch.cuda.empty_cache()
325 | devices = os.environ['CUDA_VISIBLE_DEVICES']
326 | device = devices.split(',')[1]
327 | command = f'CUDA_VISIBLE_DEVICES={device} python tools/test_ee.py --cfg {final_output_dir}/config.yaml'
328 | print(command)
329 |
330 | subprocess.run(command, shell=True)
331 |
332 | if __name__ == '__main__':
333 | main()
334 |
335 |
336 |
337 |
--------------------------------------------------------------------------------