├── .gitignore
├── LICENSE
├── README.md
├── data
├── __init__.py
├── base_data_loader.py
├── base_dataset.py
└── pothole_dataset.py
├── datasets
└── palette.txt
├── doc
├── GAL-DeepLabv3+.png
└── GAL.png
├── models
├── __init__.py
├── base_model.py
├── galnet_model.py
├── include
│ ├── __init__.py
│ └── deeplabv3plus_inc
│ │ ├── __init__.py
│ │ └── modeling
│ │ ├── __init__.py
│ │ ├── aspp.py
│ │ ├── backbone
│ │ ├── __init__.py
│ │ └── resnet.py
│ │ ├── decoder.py
│ │ └── sync_batchnorm
│ │ ├── __init__.py
│ │ ├── batchnorm.py
│ │ ├── comm.py
│ │ └── replicate.py
└── networks.py
├── options
├── __init__.py
├── base_options.py
├── test_options.py
└── train_options.py
├── scripts
├── test_gal.sh
└── train_gal.sh
├── test.py
├── train.py
└── util
├── __init__.py
└── util.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Rui (Ranger) Fan
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GAL-DeepLabv3Plus
2 |
3 | ## Introduction
4 |
5 | This is the official PyTorch implementation of **[Graph Attention Layer Evolves Semantic Segmentation for Road Pothole Detection: A Benchmark and Algorithms](https://ieeexplore.ieee.org/document/9547682)**, published in IEEE T-IP in 2021.
6 |
7 | In this repository, we provide the training and testing setups on the [pothole dataset](https://drive.google.com/file/d/1ofp-44LnYTDByOuVMOc2hBrUCUjncg3k/view?usp=sharing) ([paper](https://ieeexplore.ieee.org/abstract/document/8809907)). We have tested our code with Python 3.8.10, CUDA 11.1, and PyTorch 1.10.1.
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 | ## Setup
18 |
19 | Please set up the pothole dataset and the pretrained weights according to the following folder structure:
20 |
21 | ```
22 | GAL-DeepLabv3plus
23 | |-- data
24 | |-- datasets
25 | | |-- pothole
26 | |-- models
27 | |-- options
28 | |-- runs
29 | | |-- tdisp_gal
30 | ...
31 | ```
32 |
33 | The pothole dataset `datasets/pothole` can be downloaded from [here](https://drive.google.com/file/d/1ofp-44LnYTDByOuVMOc2hBrUCUjncg3k/view?usp=sharing), and the pretrained weights `runs/tdisp_gal` for our GAL-DeepLabv3+ can be downloaded from [here](https://drive.google.com/file/d/1wmgPUymOOPUWovwIwLdIg4hf0jsWyYja/view?usp=sharing).
34 |
35 | ## Usage
36 |
37 | ### Testing on the Pothole Dataset
38 |
39 | For testing, please first set up the `runs/tdisp_gal` and `datasets/pothole` folders as described above. Then, run the following script:
40 |
41 | ```
42 | bash ./scripts/test_gal.sh
43 | ```
44 |
45 | to test GAL-DeepLabv3+ with the transformed disparity images. The prediction results are stored in `testresults`.
46 |
47 | ### Training on the Pothole Dataset
48 |
49 | For training, please first set up the `datasets/pothole` folder as described above. Then, run the following script:
50 |
51 | ```
52 | bash ./scripts/train_gal.sh
53 | ```
54 |
55 | to train GAL-DeepLabv3+ with the transformed disparity images. The weights and the tensorboard record containing the loss curves as well as the performance on the validation set will be saved in `runs`.
56 |
57 | ## Citation
58 |
59 | If you use this code for your research, please cite our paper:
60 |
61 | ```
62 | @article{fan2021graph,
63 | title = {Graph Attention Layer Evolves Semantic Segmentation for Road Pothole Detection: A Benchmark and Algorithms},
64 | author = {Fan, Rui and Wang, Hengli and Wang, Yuan and Liu, Ming and Pitas, Ioannis},
65 | journal = {IEEE Transactions on Image Processing},
66 | volume = {30},
67 | number = {},
68 | pages = {8144-8154},
69 | year = {2021},
70 | publisher = {IEEE},
71 | doi = {10.1109/TIP.2021.3112316}
72 | }
73 | ```
74 | If you use the pothole dataset for your research, please cite our papers:
75 |
76 | ```
77 | @article{fan2019pothole,
78 | title={Pothole detection based on disparity transformation and road surface modeling},
79 | author={Fan, Rui and Ozgunalp, Umar and Hosking, Brett and Liu, Ming and Pitas, Ioannis},
80 | journal={IEEE Transactions on Image Processing},
81 | volume={29},
82 | pages={897--908},
83 | year={2019},
84 | publisher={IEEE}
85 | }
86 | @article{fan2019road,
87 | title={Road damage detection based on unsupervised disparity map segmentation},
88 | author={Fan, Rui and Liu, Ming},
89 | journal={IEEE Transactions on Intelligent Transportation Systems},
90 | volume={21},
91 | number={11},
92 | pages={4906--4911},
93 | year={2019},
94 | publisher={IEEE}
95 | }
96 | @article{fan2018road,
97 | title={Road surface 3D reconstruction based on dense subpixel disparity map estimation},
98 | author={Fan, Rui and Ai, Xiao and Dahnoun, Naim},
99 | journal={IEEE Transactions on Image Processing},
100 | volume={27},
101 | number={6},
102 | pages={3025--3035},
103 | year={2018},
104 | publisher={IEEE}
105 | }
106 | ```
107 |
108 |
109 | ## Acknowledgement
110 |
111 | Our code is inspired by [pytorch-CycleGAN-and-pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix), [pytorch_segmentation](https://github.com/yassouali/pytorch_segmentation), [pytorch-deeplab-xception](https://github.com/jfzhang95/pytorch-deeplab-xception), and [RTFNet](https://github.com/yuxiangsun/RTFNet).
113 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import torch.utils.data
3 | from data.base_data_loader import BaseDataLoader
4 | from data.base_dataset import BaseDataset
5 | import numpy
6 |
7 |
def find_dataset_using_name(dataset_name):
    """Import ``data/<dataset_name>_dataset.py`` and return its dataset class.

    The module must define a subclass of ``BaseDataset`` whose class name
    equals ``<dataset_name without underscores>dataset``, compared
    case-insensitively (e.g. ``pothole`` -> ``potholedataset``). If several
    names match, the last one found in the module namespace wins.

    Args:
        dataset_name: value of the ``--dataset`` option.

    Returns:
        The matching dataset class (not an instance).

    Raises:
        ModuleNotFoundError: if ``data.<dataset_name>_dataset`` does not exist.
        SystemExit: with status 1 if the module has no matching class.
    """
    dataset_filename = "data." + dataset_name + "_dataset"
    datasetlib = importlib.import_module(dataset_filename)

    dataset = None
    target_dataset_name = dataset_name.replace('_', '') + 'dataset'
    for name, cls in datasetlib.__dict__.items():
        # isinstance(cls, type) guards issubclass() against non-class attributes,
        # which would otherwise raise TypeError.
        if (name.lower() == target_dataset_name.lower()
                and isinstance(cls, type)
                and issubclass(cls, BaseDataset)):
            dataset = cls

    if dataset is None:
        print("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
        # Non-zero status so shell scripts / CI detect the failure
        # (exit status 0 would report success).
        raise SystemExit(1)

    return dataset
30 |
def get_option_setter(dataset_name):
    """Return the ``modify_commandline_options`` hook of the named dataset class."""
    return find_dataset_using_name(dataset_name).modify_commandline_options
34 |
def create_dataset(opt):
    """Instantiate the dataset selected by ``opt.dataset`` and initialize it with ``opt``."""
    dataset_cls = find_dataset_using_name(opt.dataset)
    created = dataset_cls()
    created.initialize(opt)
    print("dataset [%s] was created" % created.name())
    return created
41 |
def CreateDataLoader(opt):
    """Build and initialize a ``CustomDatasetDataLoader`` from parsed options."""
    loader = CustomDatasetDataLoader()
    loader.initialize(opt)
    return loader
46 |
47 |
def _seed_worker(worker_id, base_seed):
    """Seed NumPy's global RNG in a DataLoader worker with ``base_seed + worker_id``.

    Defined at module level (instead of as a lambda) so it can be pickled
    when DataLoader worker processes are started with the ``spawn`` method.
    """
    numpy.random.seed(base_seed + worker_id)


# Wrapper class of Dataset class that performs
# multi-worker data loading
class CustomDatasetDataLoader(BaseDataLoader):
    """Wrap the configured dataset in a ``torch.utils.data.DataLoader``."""

    def name(self):
        return 'CustomDatasetDataLoader'

    def initialize(self, opt):
        """Create the dataset and its DataLoader.

        Reads ``opt.batch_size``, ``opt.serial_batches`` (disables shuffling),
        ``opt.num_threads`` (number of workers) and ``opt.seed`` (per-worker
        NumPy seed offset). Incomplete final batches are dropped.
        """
        import functools

        BaseDataLoader.initialize(self, opt)
        self.dataset = create_dataset(opt)
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batch_size,
            shuffle=not opt.serial_batches,
            num_workers=int(opt.num_threads),
            drop_last=True,
            # partial of a module-level function is picklable, unlike a lambda.
            worker_init_fn=functools.partial(_seed_worker, base_seed=opt.seed))

    def load_data(self):
        return self

    def __len__(self):
        # Number of samples in the dataset, not the number of batches.
        return len(self.dataset)

    def __iter__(self):
        yield from self.dataloader
--------------------------------------------------------------------------------
/data/base_data_loader.py:
--------------------------------------------------------------------------------
class BaseDataLoader():
    """Minimal interface for data loaders configured from parsed options."""

    def __init__(self):
        pass

    def initialize(self, opt):
        """Store the parsed options on the loader as ``self.opt``."""
        self.opt = opt

    def load_data(self):
        """Return the iterable to consume; the base class provides none.

        Takes ``self`` so it can be called on an instance; subclasses override it.
        """
        return None
11 |
--------------------------------------------------------------------------------
/data/base_dataset.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 |
class BaseDataset(data.Dataset):
    """Minimal dataset interface shared by the datasets in ``data/``.

    Subclasses configure themselves in ``initialize(opt)`` rather than in
    ``__init__`` and override ``__getitem__``, ``__len__`` and ``name``.
    """

    def __init__(self):
        super().__init__()

    def name(self):
        """Return an identifier used in log messages."""
        return 'BaseDataset'

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Hook for dataset-specific CLI options; leaves ``parser`` unchanged."""
        return parser

    def initialize(self, opt):
        """Configure the dataset from parsed options; no-op in the base class."""

    def __len__(self):
        return 0
--------------------------------------------------------------------------------
/data/pothole_dataset.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | import torchvision.transforms as transforms
3 | import torch
4 | import cv2
5 | import numpy as np
6 | from data.base_dataset import BaseDataset
7 |
8 |
class potholedataset(BaseDataset):
    """Dataloader for the pothole dataset.

    Each sample ``NN.png`` is read from three sibling folders under
    ``base_dir``: ``rgb`` (colour image), ``tdisp`` (transformed disparity,
    read with ``cv2.IMREAD_ANYDEPTH`` and divided by 65535), and ``label``
    (a pixel is a pothole, class 1, when its red channel is non-zero).
    """

    # Dataset root, relative to the current working directory.
    base_dir = "./datasets/pothole"

    @staticmethod
    def modify_commandline_options(parser, is_train):
        return parser

    def initialize(self, opt):
        """Select the image indices for ``opt.phase``.

        "train" uses images 01-42; every other phase ("val", "test", ...)
        uses the held-out images 43-55.
        """
        self.opt = opt
        self.batch_size = opt.batch_size
        self.num_labels = 2

        if opt.phase == "train":
            self.image_list = np.arange(1, 43)
        else:
            self.image_list = np.arange(43, 56)

    def _imread(self, subdir, name, *flags):
        """Read ``<base_dir>/<subdir>/<name>`` with OpenCV.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file (``cv2.imread``
                signals this by returning ``None`` instead of raising).
        """
        path = os.path.join(self.base_dir, subdir, name)
        image = cv2.imread(path, *flags)
        if image is None:
            raise FileNotFoundError("Could not read image: %s" % path)
        return image

    def __getitem__(self, index):
        """Return a dict with ``rgb_image`` (3xHxW float), ``tdisp_image``
        (1xHxW float), ``label`` (HxW int64 in {0, 1}) and ``path`` (file name).
        """
        name = str(self.image_list[index]).zfill(2) + ".png"

        rgb_image = cv2.cvtColor(self._imread('rgb', name), cv2.COLOR_BGR2RGB)
        tdisp_image = self._imread('tdisp', name, cv2.IMREAD_ANYDEPTH)
        label_image = cv2.cvtColor(self._imread('label', name), cv2.COLOR_BGR2RGB)

        # Binary mask: any non-zero red value marks a pothole pixel.
        label = np.zeros((label_image.shape[0], label_image.shape[1]), dtype=np.uint8)
        label[label_image[:, :, 0] > 0] = 1

        rgb_image = rgb_image.astype(np.float32) / 255
        # Disparity is normalised by the 16-bit maximum.
        tdisp_image = tdisp_image.astype(np.float32) / 65535
        # ToTensor on float arrays only transposes HWC -> CHW (no rescaling).
        rgb_image = transforms.ToTensor()(rgb_image)
        tdisp_image = transforms.ToTensor()(tdisp_image)
        label = torch.from_numpy(label).long()

        # 'path' is the image name, used when saving predictions.
        return {'rgb_image': rgb_image, 'tdisp_image': tdisp_image, 'label': label,
                'path': name}

    def __len__(self):
        return len(self.image_list)

    def name(self):
        return 'pothole'
56 |
--------------------------------------------------------------------------------
/datasets/palette.txt:
--------------------------------------------------------------------------------
1 | 0 0 0
2 | 128 0 0
3 | 0 128 0
4 | 128 128 0
5 | 0 0 128
6 | 128 0 128
7 | 0 128 128
8 | 128 128 128
9 | 64 0 0
10 | 192 0 0
11 | 64 128 0
12 | 192 128 0
13 | 64 0 128
14 | 192 0 128
15 | 64 128 128
16 | 192 128 128
17 | 0 64 0
18 | 128 64 0
19 | 0 192 0
20 | 128 192 0
21 | 0 64 128
22 | 128 64 128
23 | 0 192 128
24 | 128 192 128
25 | 64 64 0
26 | 192 64 0
27 | 64 192 0
28 | 192 192 0
29 | 64 64 128
30 | 192 64 128
31 | 64 192 128
32 | 192 192 128
33 | 0 0 64
34 | 128 0 64
35 | 0 128 64
36 | 128 128 64
37 | 0 0 192
38 | 128 0 192
39 | 0 128 192
40 | 128 128 192
41 | 64 0 64
42 | 192 0 64
43 | 64 128 64
44 | 192 128 64
45 | 64 0 192
46 | 192 0 192
47 | 64 128 192
48 | 192 128 192
49 | 0 64 64
50 | 128 64 64
51 | 0 192 64
52 | 128 192 64
53 | 0 64 192
54 | 128 64 192
55 | 0 192 192
56 | 128 192 192
57 | 64 64 64
58 | 192 64 64
59 | 64 192 64
60 | 192 192 64
61 | 64 64 192
62 | 192 64 192
63 | 64 192 192
64 | 192 192 192
65 | 32 0 0
66 | 160 0 0
67 | 32 128 0
68 | 160 128 0
69 | 32 0 128
70 | 160 0 128
71 | 32 128 128
72 | 160 128 128
73 | 96 0 0
74 | 224 0 0
75 | 96 128 0
76 | 224 128 0
77 | 96 0 128
78 | 224 0 128
79 | 96 128 128
80 | 224 128 128
81 | 32 64 0
82 | 160 64 0
83 | 32 192 0
84 | 160 192 0
85 | 32 64 128
86 | 160 64 128
87 | 32 192 128
88 | 160 192 128
89 | 96 64 0
90 | 224 64 0
91 | 96 192 0
92 | 224 192 0
93 | 96 64 128
94 | 224 64 128
95 | 96 192 128
96 | 224 192 128
97 | 32 0 64
98 | 160 0 64
99 | 32 128 64
100 | 160 128 64
101 | 32 0 192
102 | 160 0 192
103 | 32 128 192
104 | 160 128 192
105 | 96 0 64
106 | 224 0 64
107 | 96 128 64
108 | 224 128 64
109 | 96 0 192
110 | 224 0 192
111 | 96 128 192
112 | 224 128 192
113 | 32 64 64
114 | 160 64 64
115 | 32 192 64
116 | 160 192 64
117 | 32 64 192
118 | 160 64 192
119 | 32 192 192
120 | 160 192 192
121 | 96 64 64
122 | 224 64 64
123 | 96 192 64
124 | 224 192 64
125 | 96 64 192
126 | 224 64 192
127 | 96 192 192
128 | 224 192 192
129 | 0 32 0
130 | 128 32 0
131 | 0 160 0
132 | 128 160 0
133 | 0 32 128
134 | 128 32 128
135 | 0 160 128
136 | 128 160 128
137 | 64 32 0
138 | 192 32 0
139 | 64 160 0
140 | 192 160 0
141 | 64 32 128
142 | 192 32 128
143 | 64 160 128
144 | 192 160 128
145 | 0 96 0
146 | 128 96 0
147 | 0 224 0
148 | 128 224 0
149 | 0 96 128
150 | 128 96 128
151 | 0 224 128
152 | 128 224 128
153 | 64 96 0
154 | 192 96 0
155 | 64 224 0
156 | 192 224 0
157 | 64 96 128
158 | 192 96 128
159 | 64 224 128
160 | 192 224 128
161 | 0 32 64
162 | 128 32 64
163 | 0 160 64
164 | 128 160 64
165 | 0 32 192
166 | 128 32 192
167 | 0 160 192
168 | 128 160 192
169 | 64 32 64
170 | 192 32 64
171 | 64 160 64
172 | 192 160 64
173 | 64 32 192
174 | 192 32 192
175 | 64 160 192
176 | 192 160 192
177 | 0 96 64
178 | 128 96 64
179 | 0 224 64
180 | 128 224 64
181 | 0 96 192
182 | 128 96 192
183 | 0 224 192
184 | 128 224 192
185 | 64 96 64
186 | 192 96 64
187 | 64 224 64
188 | 192 224 64
189 | 64 96 192
190 | 192 96 192
191 | 64 224 192
192 | 192 224 192
193 | 32 32 0
194 | 160 32 0
195 | 32 160 0
196 | 160 160 0
197 | 32 32 128
198 | 160 32 128
199 | 32 160 128
200 | 160 160 128
201 | 96 32 0
202 | 224 32 0
203 | 96 160 0
204 | 224 160 0
205 | 96 32 128
206 | 224 32 128
207 | 96 160 128
208 | 224 160 128
209 | 32 96 0
210 | 160 96 0
211 | 32 224 0
212 | 160 224 0
213 | 32 96 128
214 | 160 96 128
215 | 32 224 128
216 | 160 224 128
217 | 96 96 0
218 | 224 96 0
219 | 96 224 0
220 | 224 224 0
221 | 96 96 128
222 | 224 96 128
223 | 96 224 128
224 | 224 224 128
225 | 32 32 64
226 | 160 32 64
227 | 32 160 64
228 | 160 160 64
229 | 32 32 192
230 | 160 32 192
231 | 32 160 192
232 | 160 160 192
233 | 96 32 64
234 | 224 32 64
235 | 96 160 64
236 | 224 160 64
237 | 96 32 192
238 | 224 32 192
239 | 96 160 192
240 | 224 160 192
241 | 32 96 64
242 | 160 96 64
243 | 32 224 64
244 | 160 224 64
245 | 32 96 192
246 | 160 96 192
247 | 32 224 192
248 | 160 224 192
249 | 96 96 64
250 | 224 96 64
251 | 96 224 64
252 | 224 224 64
253 | 96 96 192
254 | 224 96 192
255 | 96 224 192
256 | 224 224 192
257 |
--------------------------------------------------------------------------------
/doc/GAL-DeepLabv3+.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruirangerfan/GAL-DeepLabv3Plus/da613d0907ebf2908978a08f72b1a58e23caafb9/doc/GAL-DeepLabv3+.png
--------------------------------------------------------------------------------
/doc/GAL.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruirangerfan/GAL-DeepLabv3Plus/da613d0907ebf2908978a08f72b1a58e23caafb9/doc/GAL.png
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from models.base_model import BaseModel
3 |
4 |
def find_model_using_name(model_name):
    """Import ``models/<model_name>_model.py`` and return its model class.

    The module must define a subclass of ``BaseModel`` whose class name
    equals ``<model_name without underscores>model``, compared
    case-insensitively (e.g. ``galnet`` -> ``GALNetModel``). If several
    names match, the last one found in the module namespace wins.

    Args:
        model_name: value of the ``--model`` option.

    Returns:
        The matching model class (not an instance).

    Raises:
        ModuleNotFoundError: if ``models.<model_name>_model`` does not exist.
        SystemExit: with status 1 if the module has no matching class.
    """
    model_filename = "models." + model_name + "_model"
    modellib = importlib.import_module(model_filename)

    model = None
    target_model_name = model_name.replace('_', '') + 'model'
    for name, cls in modellib.__dict__.items():
        # isinstance(cls, type) guards issubclass() against non-class attributes,
        # which would otherwise raise TypeError.
        if (name.lower() == target_model_name.lower()
                and isinstance(cls, type)
                and issubclass(cls, BaseModel)):
            model = cls

    if model is None:
        print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
        # Non-zero status so shell scripts / CI detect the failure
        # (exit status 0 would report success).
        raise SystemExit(1)

    return model
27 |
def get_option_setter(model_name):
    """Return the ``modify_commandline_options`` hook of the named model class."""
    return find_model_using_name(model_name).modify_commandline_options
31 |
def create_model(opt, dataset):
    """Instantiate the model named by ``opt.model`` and initialize it with ``opt`` and ``dataset``."""
    model_cls = find_model_using_name(opt.model)
    created = model_cls()
    created.initialize(opt, dataset)
    print("model [%s] was created" % created.name())
    return created
38 |
--------------------------------------------------------------------------------
/models/base_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from collections import OrderedDict
4 | from . import networks
5 |
6 |
class BaseModel():
    """Common base for models: device setup, checkpoint I/O, schedulers, logging.

    Subclasses fill ``model_names``, ``loss_names`` and ``visual_names``, store
    each network as ``self.net<Name>`` and each loss as ``self.loss_<name>``;
    the helpers below look those attributes up by name.
    """

    # modify parser to add command line options,
    # and also change the default values if needed
    @staticmethod
    def modify_commandline_options(parser, is_train):
        return parser

    def name(self):
        return 'BaseModel'

    def initialize(self, opt):
        """Store options and derive the device and checkpoint directory.

        The device is ``cuda:<opt.gpu_ids[0]>`` when ``opt.gpu_ids`` is
        non-empty, otherwise the CPU. Checkpoints live in
        ``<opt.checkpoints_dir>/<opt.name>``.
        """
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
        self.loss_names = []
        self.model_names = []
        self.visual_names = []
        self.image_names = []
        self.image_oriSize = []

    def set_input(self, input):
        self.input = input

    def forward(self):
        pass

    # load and print networks; create schedulers
    def setup(self, opt, parser=None):
        """Create LR schedulers (training only), load weights, print networks.

        Weights for epoch ``opt.epoch`` are loaded when testing or when
        ``opt.continue_train`` is set. ``parser`` is unused.
        """
        if self.isTrain:
            self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]

        if not self.isTrain or opt.continue_train:
            self.load_networks(opt.epoch)
        self.print_networks(opt.verbose)

    # make models eval mode during test time
    def eval(self):
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                net.eval()

    def train(self):
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                net.train()

    # used in test time, wrapping `forward` in no_grad() so we don't save
    # intermediate steps for backprop
    def test(self):
        with torch.no_grad():
            self.forward()

    # get image names
    def get_image_names(self):
        return self.image_names

    def optimize_parameters(self):
        pass

    # update learning rate (called once every epoch)
    def update_learning_rate(self):
        for scheduler in self.schedulers:
            scheduler.step()
        lr = self.optimizers[0].param_groups[0]['lr']
        print('learning rate = %.7f' % lr)

    # return visualization images. train.py will display these images in tensorboardX
    def get_current_visuals(self):
        visual_ret = OrderedDict()
        for name in self.visual_names:
            if isinstance(name, str):
                visual_ret[name] = getattr(self, name)
        return visual_ret

    # return training losses/errors. train.py will print out these errors as debugging information
    def get_current_losses(self):
        errors_ret = OrderedDict()
        for name in self.loss_names:
            if isinstance(name, str):
                # float(...) works for both scalar tensor and float number
                errors_ret[name] = float(getattr(self, 'loss_' + name))
        return errors_ret

    # save models to the disk
    def save_networks(self, epoch):
        """Save each network's state_dict to ``<save_dir>/<epoch>_net_<name>.pth``.

        A ``DataParallel`` wrapper is stripped before saving so the keys never
        carry a ``module.`` prefix, which is the layout ``load_networks``
        expects (it unwraps ``DataParallel`` before loading). The network is
        moved to the CPU for saving and back to ``gpu_ids[0]`` when CUDA is used.
        """
        for name in self.model_names:
            if isinstance(name, str):
                save_filename = '%s_net_%s.pth' % (epoch, name)
                save_path = os.path.join(self.save_dir, save_filename)
                net = getattr(self, 'net' + name)
                module = net.module if isinstance(net, torch.nn.DataParallel) else net

                torch.save(module.cpu().state_dict(), save_path)
                if len(self.gpu_ids) > 0 and torch.cuda.is_available():
                    net.cuda(self.gpu_ids[0])

    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
        # Walk the dotted key path; at the leaf, drop InstanceNorm running
        # stats / num_batches_tracked entries the current module does not keep.
        key = keys[i]
        if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
            if module.__class__.__name__.startswith('InstanceNorm') and \
                    (key == 'running_mean' or key == 'running_var'):
                if getattr(module, key) is None:
                    state_dict.pop('.'.join(keys))
            if module.__class__.__name__.startswith('InstanceNorm') and \
                    (key == 'num_batches_tracked'):
                state_dict.pop('.'.join(keys))
        else:
            self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)

    # load models from the disk
    def load_networks(self, epoch):
        """Load ``<save_dir>/<epoch>_net_<name>.pth`` into each network (unwrapping DataParallel)."""
        for name in self.model_names:
            if isinstance(name, str):
                load_filename = '%s_net_%s.pth' % (epoch, name)
                load_path = os.path.join(self.save_dir, load_filename)
                net = getattr(self, 'net' + name)
                if isinstance(net, torch.nn.DataParallel):
                    net = net.module
                print('loading the model from %s' % load_path)
                # if you are using PyTorch newer than 0.4 (e.g., built from
                # GitHub source), you can remove str() on self.device
                state_dict = torch.load(load_path, map_location=str(self.device))
                if hasattr(state_dict, '_metadata'):
                    del state_dict._metadata

                for key in list(state_dict.keys()):
                    self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
                net.load_state_dict(state_dict)

    # print network information
    def print_networks(self, verbose):
        print('---------- Networks initialized -------------')
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                num_params = 0
                for param in net.parameters():
                    num_params += param.numel()
                if verbose:
                    print(net)
                print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
        print('-----------------------------------------------')

    # set requires_grad=False to avoid computation
    def set_requires_grad(self, nets, requires_grad=False):
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad
163 |
--------------------------------------------------------------------------------
/models/galnet_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .base_model import BaseModel
3 | from . import networks
4 |
5 |
class GALNetModel(BaseModel):
    """GAL-DeepLabv3+ model: owns the network, segmentation loss and SGD optimizer."""

    def name(self):
        return 'GALNet'

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Add model-specific CLI options.

        During training this registers ``--lambda_L1`` (default 100.0); the
        value is not read anywhere in this class.
        """
        if not is_train:
            return parser
        parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss')
        return parser

    def initialize(self, opt, dataset):
        """Build the network, loss and, when training, the optimizer.

        Args:
            opt: parsed options; reads ``input`` ("rgb" or "tdisp"), ``gal``,
                ``isTrain`` and, for training, ``lr``, ``momentum``, ``weight_decay``.
            dataset: dataset instance; only ``dataset.num_labels`` is used.

        Raises:
            NotImplementedError: if ``opt.input`` is neither "rgb" nor "tdisp".
        """
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain
        # Attribute-name lists consumed by BaseModel's loss/visual/checkpoint helpers.
        self.loss_names = ['segmentation']
        self.visual_names = ['rgb_image', 'tdisp_image', 'label', 'output']
        self.model_names = ['GALNet']

        if opt.input == "rgb":
            banner, channels = "Using RGB images as input", 3
        elif opt.input == "tdisp":
            banner, channels = "Using transformed disparity images as input", 1
        else:
            raise NotImplementedError
        print(banner)
        self.input_channels = channels

        self.netGALNet = networks.define_GALNet(
            dataset.num_labels,
            gpu_ids=self.gpu_ids,
            input_channels=self.input_channels,
            use_gal=opt.gal,
        )
        self.criterionSegmentation = networks.SegmantationLoss(class_weights=None).to(self.device)

        if not self.isTrain:
            return
        self.optimizers = []
        self.optimizer = torch.optim.SGD(
            self.netGALNet.parameters(),
            lr=opt.lr,
            momentum=opt.momentum,
            weight_decay=opt.weight_decay,
        )
        self.optimizers.append(self.optimizer)
        self.set_requires_grad(self.netGALNet, True)

    def set_input(self, input):
        """Move the batch tensors to ``self.device`` and keep the image names."""
        self.rgb_image = input['rgb_image'].to(self.device)
        self.tdisp_image = input['tdisp_image'].to(self.device)
        self.label = input['label'].to(self.device)
        self.image_names = input['path']

    def forward(self):
        """Run the network on the input chosen by ``opt.input``; result in ``self.output``."""
        if self.opt.input == "rgb":
            net_input = self.rgb_image
        elif self.opt.input == "tdisp":
            net_input = self.tdisp_image
        else:
            raise NotImplementedError
        self.output = self.netGALNet(net_input)

    def get_loss(self):
        self.loss_segmentation = self.criterionSegmentation(self.output, self.label)

    def backward(self):
        self.loss_segmentation.backward()

    def optimize_parameters(self):
        """One SGD step: forward, clear grads, compute loss, backprop, update."""
        self.forward()
        self.optimizer.zero_grad()
        self.get_loss()
        self.backward()
        self.optimizer.step()
74 |
--------------------------------------------------------------------------------
/models/include/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruirangerfan/GAL-DeepLabv3Plus/da613d0907ebf2908978a08f72b1a58e23caafb9/models/include/__init__.py
--------------------------------------------------------------------------------
/models/include/deeplabv3plus_inc/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruirangerfan/GAL-DeepLabv3Plus/da613d0907ebf2908978a08f72b1a58e23caafb9/models/include/deeplabv3plus_inc/__init__.py
--------------------------------------------------------------------------------
/models/include/deeplabv3plus_inc/modeling/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruirangerfan/GAL-DeepLabv3Plus/da613d0907ebf2908978a08f72b1a58e23caafb9/models/include/deeplabv3plus_inc/modeling/__init__.py
--------------------------------------------------------------------------------
/models/include/deeplabv3plus_inc/modeling/aspp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from .sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
5 |
6 |
class _ASPPModule(nn.Module):
    """A single ASPP branch: atrous convolution -> BatchNorm -> ReLU."""

    def __init__(self, inplanes, planes, kernel_size, padding, dilation, BatchNorm):
        super().__init__()
        self.atrous_conv = nn.Conv2d(
            inplanes,
            planes,
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
            dilation=dilation,
            bias=False,
        )
        self.bn = BatchNorm(planes)
        self.relu = nn.ReLU()

        self._init_weight()

    def forward(self, x):
        return self.relu(self.bn(self.atrous_conv(x)))

    def _init_weight(self):
        # Kaiming-normal conv weights; BatchNorm layers start as identity (scale 1, shift 0).
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                torch.nn.init.kaiming_normal_(module.weight)
            elif isinstance(module, (SynchronizedBatchNorm2d, nn.BatchNorm2d)):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
33 |
class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling head.

    Four parallel atrous branches (one 1x1, three dilated 3x3) and an
    image-level pooling branch each produce 256 channels. These are
    concatenated (5 * 256 = 1280 channels) and projected back to 256.

    Args:
        backbone: backbone name; selects the input channel count
            ('drn' -> 512, 'mobilenet' -> 320, anything else -> 2048).
        output_stride: 16 or 8; selects the dilation rates.
        BatchNorm: normalization layer class.

    Raises:
        NotImplementedError: if ``output_stride`` is neither 16 nor 8.
    """

    def __init__(self, backbone, output_stride, BatchNorm):
        super(ASPP, self).__init__()
        if backbone == 'drn':
            inplanes = 512
        elif backbone == 'mobilenet':
            inplanes = 320
        else:
            inplanes = 2048

        dilation_table = {16: [1, 6, 12, 18], 8: [1, 12, 24, 36]}
        if output_stride not in dilation_table:
            raise NotImplementedError
        dilations = dilation_table[output_stride]

        # Registers aspp1..aspp4 in order. The first branch is a plain 1x1
        # conv; the rest are 3x3 convs padded by their dilation so the
        # spatial size is preserved.
        for branch_idx, rate in enumerate(dilations, start=1):
            if branch_idx == 1:
                kernel, pad = 1, 0
            else:
                kernel, pad = 3, rate
            setattr(self, 'aspp{}'.format(branch_idx),
                    _ASPPModule(inplanes, 256, kernel, padding=pad, dilation=rate, BatchNorm=BatchNorm))

        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(inplanes, 256, 1, stride=1, bias=False),
            BatchNorm(256),
            nn.ReLU(),
        )

        self.conv1 = nn.Conv2d(1280, 256, 1, bias=False)  # 1280 = 5 branches * 256
        self.bn1 = BatchNorm(256)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        self._init_weight()

    def forward(self, x):
        branches = [self.aspp1(x), self.aspp2(x), self.aspp3(x), self.aspp4(x)]
        # Broadcast the 1x1 pooled features back to the branch resolution.
        pooled = self.global_avg_pool(x)
        pooled = F.interpolate(pooled, size=branches[-1].size()[2:], mode='bilinear', align_corners=True)
        branches.append(pooled)

        fused = torch.cat(branches, dim=1)
        fused = self.relu(self.bn1(self.conv1(fused)))
        return self.dropout(fused)

    def _init_weight(self):
        # Kaiming-normal conv weights; batch-norm scale 1 and shift 0.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                torch.nn.init.kaiming_normal_(module.weight)
                continue
            if isinstance(module, SynchronizedBatchNorm2d) or isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
93 |
94 |
def build_aspp(backbone, output_stride, BatchNorm):
    """Factory for :class:`ASPP`; all arguments are forwarded unchanged."""
    aspp_head = ASPP(backbone, output_stride, BatchNorm)
    return aspp_head
--------------------------------------------------------------------------------
/models/include/deeplabv3plus_inc/modeling/backbone/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruirangerfan/GAL-DeepLabv3Plus/da613d0907ebf2908978a08f72b1a58e23caafb9/models/include/deeplabv3plus_inc/modeling/backbone/__init__.py
--------------------------------------------------------------------------------
/models/include/deeplabv3plus_inc/modeling/backbone/resnet.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.utils.model_zoo as model_zoo
5 | from models.include.deeplabv3plus_inc.modeling.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
6 |
7 |
class Bottleneck(nn.Module):
    """ResNet bottleneck residual block: 1x1 -> 3x3 -> 1x1 convolutions.

    The block outputs ``planes * 4`` channels. Only the 3x3 convolution
    carries ``stride`` and ``dilation``, and its padding equals the dilation.

    Args:
        inplanes: input channel count.
        planes: bottleneck width.
        stride: stride of the 3x3 convolution.
        dilation: dilation (and padding) of the 3x3 convolution.
        downsample: optional module applied to the block input to build the
            residual branch when shapes differ; identity if None.
        BatchNorm: normalization layer class.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None):
        super(Bottleneck, self).__init__()
        out_planes = planes * 4  # matches `expansion`

        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = BatchNorm(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               dilation=dilation, padding=dilation, bias=False)
        self.bn2 = BatchNorm(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = BatchNorm(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)
46 |
class ResNet(nn.Module):
    """Dilated ResNet backbone for DeepLabv3+.

    ``forward`` returns ``(x, low_level_feat)``: the ``layer4`` output and the
    ``layer1`` output.

    Args:
        block: residual block class (e.g. ``Bottleneck``); must expose ``expansion``.
        layers: per-stage block counts. ``layers[0..2]`` size ``layer1..layer3``;
            ``layer4`` is always a three-block multi-grid unit (rates 1, 2, 4
            times the stage dilation). The full tuple also selects the
            pretrained checkpoint.
        output_stride: 16 or 8; later stages trade stride for dilation.
        BatchNorm: normalization layer class.
        pretrained: if True, load ImageNet weights after initialization.
        num_ch: number of input channels of the stem convolution.

    Raises:
        NotImplementedError: if ``output_stride`` is neither 16 nor 8.
    """

    # ImageNet checkpoints keyed by per-stage block counts. Each depth needs
    # its own file: loading ResNet-50 weights into ResNet-101 fills only
    # layer3.0-5 and leaves the remaining layer3 blocks randomly initialized.
    _PRETRAINED_URLS = {
        (3, 4, 6, 3): 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
        (3, 4, 23, 3): 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    }

    def __init__(self, block, layers, output_stride, BatchNorm, pretrained=True, num_ch=3):
        self.inplanes = 64
        super(ResNet, self).__init__()

        self.num_ch = num_ch
        self._stage_blocks = tuple(layers)

        blocks = [1, 2, 4]  # multi-grid rates for layer4
        if output_stride == 16:
            strides = [1, 2, 2, 1]
            dilations = [1, 1, 1, 2]
        elif output_stride == 8:
            strides = [1, 2, 1, 1]
            dilations = [1, 1, 2, 4]
        else:
            raise NotImplementedError

        # Modules
        self.conv1 = nn.Conv2d(num_ch, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = BatchNorm(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0], BatchNorm=BatchNorm)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1], BatchNorm=BatchNorm)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2], BatchNorm=BatchNorm)
        self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm)
        self._init_weight()

        if pretrained:
            self._load_pretrained_model()

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
        """Build one stage of ``blocks`` residual blocks at a fixed dilation.

        The first block applies ``stride`` and gets a 1x1 projection shortcut
        when the stride or the channel count changes. Updates ``self.inplanes``.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                BatchNorm(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, dilation, downsample, BatchNorm))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, dilation=dilation, BatchNorm=BatchNorm))

        return nn.Sequential(*layers)

    def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
        """Build a multi-grid stage: block ``i`` uses dilation ``blocks[i] * dilation``.

        Updates ``self.inplanes``.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                BatchNorm(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, dilation=blocks[0]*dilation,
                            downsample=downsample, BatchNorm=BatchNorm))
        self.inplanes = planes * block.expansion
        for i in range(1, len(blocks)):
            layers.append(block(self.inplanes, planes, stride=1,
                                dilation=blocks[i]*dilation, BatchNorm=BatchNorm))

        return nn.Sequential(*layers)

    def forward(self, input):
        x = self.conv1(input)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        low_level_feat = x  # fed to the decoder's skip connection
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x, low_level_feat

    def _init_weight(self):
        # Conv weights ~ N(0, sqrt(2 / (k*k*out_channels))); BN scale 1, shift 0.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, SynchronizedBatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _load_pretrained_model(self):
        """Copy ImageNet weights into every parameter/buffer with a matching name and shape.

        The checkpoint is chosen from the stage depths. Unknown depth
        configurations fall back to the ResNet-50 checkpoint (the previous
        behavior). For ``num_ch == 1`` the RGB stem filters are averaged into
        one channel. Any other shape mismatch (e.g. ``num_ch`` not in (1, 3))
        keeps the random initialization instead of failing in
        ``load_state_dict``.
        """
        default_url = self._PRETRAINED_URLS[(3, 4, 6, 3)]
        url = self._PRETRAINED_URLS.get(self._stage_blocks, default_url)
        pretrain_dict = model_zoo.load_url(url)
        state_dict = self.state_dict()
        for k, v in pretrain_dict.items():
            if k not in state_dict:
                continue  # e.g. the ImageNet classifier ('fc.*')
            if k == "conv1.weight" and self.num_ch == 1:
                v = torch.unsqueeze(torch.mean(v, dim=1), dim=1)
            if v.shape != state_dict[k].shape:
                continue
            state_dict[k] = v
        self.load_state_dict(state_dict)
152 |
153 |
def ResNet50(output_stride, BatchNorm, pretrained=True, num_ch=1):
    """Construct a ResNet-50 backbone (stage depths 3-4-6-3).

    Args:
        output_stride: 16 or 8.
        BatchNorm: normalization layer class.
        pretrained (bool): if True, load ImageNet weights.
        num_ch: number of input channels (defaults to 1 here).
    """
    stage_depths = [3, 4, 6, 3]
    return ResNet(Bottleneck, stage_depths, output_stride, BatchNorm,
                  pretrained=pretrained, num_ch=num_ch)
161 |
def ResNet101(output_stride, BatchNorm, pretrained=True, num_ch=1):
    """Construct a ResNet-101 backbone (stage depths 3-4-23-3).

    Args:
        output_stride: 16 or 8.
        BatchNorm: normalization layer class.
        pretrained (bool): if True, load ImageNet weights.
        num_ch: number of input channels (defaults to 1 here).
    """
    stage_depths = [3, 4, 23, 3]
    return ResNet(Bottleneck, stage_depths, output_stride, BatchNorm,
                  pretrained=pretrained, num_ch=num_ch)
169 |
--------------------------------------------------------------------------------
/models/include/deeplabv3plus_inc/modeling/decoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from .sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
5 |
6 |
class Decoder(nn.Module):
    """DeepLabv3+ decoder.

    Reduces the low-level feature to 48 channels, upsamples the high-level
    feature to the low-level resolution, concatenates the two and refines
    them into ``num_classes`` output channels.

    Args:
        num_classes: number of output channels of the final 1x1 conv.
        backbone: 'resnet' or 'drn' (256 low-level channels), 'xception' (128),
            'mobilenet' (24).
        BatchNorm: normalization layer class.

    Raises:
        NotImplementedError: for any other ``backbone``.
    """

    def __init__(self, num_classes, backbone, BatchNorm):
        super(Decoder, self).__init__()
        if backbone in ('resnet', 'drn'):
            low_level_inplanes = 256
        elif backbone == 'xception':
            low_level_inplanes = 128
        elif backbone == 'mobilenet':
            low_level_inplanes = 24
        else:
            raise NotImplementedError

        self.conv1 = nn.Conv2d(low_level_inplanes, 48, 1, bias=False)
        self.bn1 = BatchNorm(48)
        self.relu = nn.ReLU()
        # 304 input channels = 48 reduced low-level channels + 256 channels
        # expected from the high-level input `x`.
        self.last_conv = nn.Sequential(
            nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False),
            BatchNorm(256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
            BatchNorm(256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Conv2d(256, num_classes, kernel_size=1, stride=1),
        )
        self._init_weight()

    def forward(self, x, low_level_feat):
        skip = self.relu(self.bn1(self.conv1(low_level_feat)))
        upsampled = F.interpolate(x, size=skip.size()[2:], mode='bilinear', align_corners=True)
        return self.last_conv(torch.cat((upsampled, skip), dim=1))

    def _init_weight(self):
        # Kaiming-normal conv weights; batch-norm scale 1 and shift 0.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                torch.nn.init.kaiming_normal_(module.weight)
                continue
            if isinstance(module, SynchronizedBatchNorm2d) or isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
54 |
55 |
def build_decoder(num_classes, backbone, BatchNorm):
    """Factory for :class:`Decoder`; all arguments are forwarded unchanged."""
    decoder = Decoder(num_classes, backbone, BatchNorm)
    return decoder
58 |
--------------------------------------------------------------------------------
/models/include/deeplabv3plus_inc/modeling/sync_batchnorm/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : __init__.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
12 | from .replicate import DataParallelWithCallback, patch_replication_callback
--------------------------------------------------------------------------------
/models/include/deeplabv3plus_inc/modeling/sync_batchnorm/batchnorm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : batchnorm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import collections
12 |
13 | import torch
14 | import torch.nn.functional as F
15 |
16 | from torch.nn.modules.batchnorm import _BatchNorm
17 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
18 |
19 | from .comm import SyncMaster
20 |
21 | __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']
22 |
23 |
24 | def _sum_ft(tensor):
25 | """sum over the first and last dimention"""
26 | return tensor.sum(dim=0).sum(dim=-1)
27 |
28 |
29 | def _unsqueeze_ft(tensor):
30 | """add new dementions at the front and the tail"""
31 | return tensor.unsqueeze(0).unsqueeze(-1)
32 |
33 |
34 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
35 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
36 |
37 |
class _SynchronizedBatchNorm(_BatchNorm):
    """Batch norm whose training statistics are synchronized across replicas.

    Outside replicated training (``_is_parallel`` False) or in eval mode it
    defers to ``F.batch_norm``. Replicated mode is enabled by the
    ``__data_parallel_replicate__`` callback. In that mode each replica sends
    its per-channel sum and sum of squares to the master replica
    (``copy_id == 0``). The master reduces them, computes the global mean and
    inverse std, updates the running statistics and broadcasts the result back.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
        super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)

        self._sync_master = SyncMaster(self._data_parallel_master)

        self._is_parallel = False
        self._parallel_id = None
        self._slave_pipe = None

    def _check_input_dim(self, input):
        """Raise ValueError unless dim 1 of ``input`` has ``num_features`` channels.

        Subclasses may add rank checks on top of this. Defining it here also
        keeps their ``super()._check_input_dim`` calls from reaching PyTorch's
        base implementation, which raises NotImplementedError.
        """
        if input.dim() < 2 or input.size(1) != self.num_features:
            raise ValueError('expected input with {} channels in dim 1 (got shape {})'
                             .format(self.num_features, tuple(input.shape)))

    def forward(self, input):
        self._check_input_dim(input)

        # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
        if not (self._is_parallel and self.training):
            return F.batch_norm(
                input, self.running_mean, self.running_var, self.weight, self.bias,
                self.training, self.momentum, self.eps)

        # Resize the input to (B, C, -1).
        input_shape = input.size()
        input = input.view(input.size(0), self.num_features, -1)

        # Compute the sum and square-sum.
        sum_size = input.size(0) * input.size(2)
        input_sum = _sum_ft(input)
        input_ssum = _sum_ft(input ** 2)

        # Reduce-and-broadcast the statistics.
        if self._parallel_id == 0:
            mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
        else:
            mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))

        # Compute the output.
        if self.affine:
            # MJY:: Fuse the multiplication for speed.
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
        else:
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)

        # Reshape it.
        return output.view(input_shape)

    def __data_parallel_replicate__(self, ctx, copy_id):
        """Replication callback: mark this copy as parallel and wire it to the master."""
        self._is_parallel = True
        self._parallel_id = copy_id

        # parallel_id == 0 means master device.
        if self._parallel_id == 0:
            ctx.sync_master = self._sync_master
        else:
            self._slave_pipe = ctx.sync_master.register_slave(copy_id)

    def _data_parallel_master(self, intermediates):
        """Reduce the sum and square-sum, compute the statistics, and broadcast them.

        Args:
            intermediates: list of ``(copy_id, _ChildMessage)`` from every replica.

        Returns:
            list of ``(copy_id, _MasterMessage)``, one per replica.
        """

        # Always using same "device order" makes the ReduceAdd operation faster.
        # Thanks to:: Tete Xiao (http://tetexiao.com/)
        intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())

        to_reduce = [i[1][:2] for i in intermediates]
        to_reduce = [j for i in to_reduce for j in i]  # flatten
        target_gpus = [i[1].sum.get_device() for i in intermediates]

        sum_size = sum([i[1].sum_size for i in intermediates])
        sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
        mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)

        broadcasted = Broadcast.apply(target_gpus, mean, inv_std)

        outputs = []
        for i, rec in enumerate(intermediates):
            outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2])))

        return outputs

    def _compute_mean_std(self, sum_, ssum, size):
        """Compute the mean and inverse std from sum and square-sum; update running stats.

        The running variance uses the unbiased estimate. The normalization
        uses the biased estimate plus ``eps``, as ``F.batch_norm`` does on the
        non-parallel path.

        Returns:
            ``(mean, inv_std)`` with ``inv_std = (biased_var + eps) ** -0.5``.
        """
        assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
        mean = sum_ / size
        sumvar = ssum - sum_ * mean
        unbias_var = sumvar / (size - 1)
        bias_var = sumvar / size

        # Update in place rather than rebinding the attributes. On the master
        # replica the buffers may share storage with the wrapped module's
        # buffers, so rebinding could leave the original running stats untouched.
        self.running_mean.mul_(1 - self.momentum).add_(self.momentum * mean.data)
        self.running_var.mul_(1 - self.momentum).add_(self.momentum * unbias_var.data)

        return mean, (bias_var + self.eps) ** -0.5
126 |
127 |
class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
    r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
    mini-batch.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm1d in that the mean and
    standard deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    only that device's statistics. This is fast and easy to implement, but the
    statistics may be inaccurate. In this synchronized version the statistics are
    computed over all training samples distributed across the devices.

    For the one-GPU or CPU-only case, this module behaves exactly like the
    built-in PyTorch implementation.

    The mean and standard deviation are calculated per-dimension over
    the mini-batches, and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps running estimates of its computed mean
    and variance, updated with a default momentum of 0.1.

    During evaluation, these running estimates are used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of size
            `batch_size x num_features [x width]`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C)` or :math:`(N, C, L)`
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100, affine=False)
        >>> input = torch.randn(20, 100)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        """Raise ValueError unless ``input`` is 2D/3D with ``num_features`` channels."""
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError('expected 2D or 3D input (got {}D input)'
                             .format(input.dim()))
        # Checked here instead of via super(): PyTorch's base
        # `_check_input_dim` raises NotImplementedError on current releases.
        if input.size(1) != self.num_features:
            raise ValueError('got {}-feature tensor, expected {}'
                             .format(input.size(1), self.num_features))
178 |
179 |
class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
    of 3d inputs.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm2d in that the mean and
    standard deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    only that device's statistics. This is fast and easy to implement, but the
    statistics may be inaccurate. In this synchronized version the statistics are
    computed over all training samples distributed across the devices.

    For the one-GPU or CPU-only case, this module behaves exactly like the
    built-in PyTorch implementation.

    The mean and standard deviation are calculated per-dimension over
    the mini-batches, and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps running estimates of its computed mean
    and variance, updated with a default momentum of 0.1.

    During evaluation, these running estimates are used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        """Raise ValueError unless ``input`` is 4D with ``num_features`` channels."""
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))
        # Checked here instead of via super(): PyTorch's base
        # `_check_input_dim` raises NotImplementedError on current releases.
        if input.size(1) != self.num_features:
            raise ValueError('got {}-feature tensor, expected {}'
                             .format(input.size(1), self.num_features))
230 |
231 |
class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
    of 4d inputs.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm3d in that the mean and
    standard deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    only that device's statistics. This is fast and easy to implement, but the
    statistics may be inaccurate. In this synchronized version the statistics are
    computed over all training samples distributed across the devices.

    For the one-GPU or CPU-only case, this module behaves exactly like the
    built-in PyTorch implementation.

    The mean and standard deviation are calculated per-dimension over
    the mini-batches, and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps running estimates of its computed mean
    and variance, updated with a default momentum of 0.1.

    During evaluation, these running estimates are used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
    or Spatio-temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x depth x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45, 10)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        """Raise ValueError unless ``input`` is 5D with ``num_features`` channels."""
        if input.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'
                             .format(input.dim()))
        # Checked here instead of via super(): PyTorch's base
        # `_check_input_dim` raises NotImplementedError on current releases.
        if input.size(1) != self.num_features:
            raise ValueError('got {}-feature tensor, expected {}'
                             .format(input.size(1), self.num_features))
--------------------------------------------------------------------------------
/models/include/deeplabv3plus_inc/modeling/sync_batchnorm/comm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : comm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | from multiprocessing import Queue
12 | #import queue
13 | import collections
14 | import threading
15 |
16 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
17 |
18 |
class FutureResult(object):
    """A thread-safe, single-slot future. Used only as a one-to-one pipe.

    ``put`` stores a result (the previous one must already have been fetched)
    and wakes a waiting ``get``. ``get`` blocks until a result is present,
    then returns it and empties the slot. Any value, including ``None``, can
    be passed through.
    """

    # Private marker for "no result stored", so that None is a valid result.
    _EMPTY = object()

    def __init__(self):
        self._result = self._EMPTY
        self._lock = threading.Lock()
        self._cond = threading.Condition(self._lock)

    def put(self, result):
        with self._lock:
            assert self._result is self._EMPTY, 'Previous result has\'t been fetched.'
            self._result = result
            self._cond.notify()

    def get(self):
        with self._lock:
            # Re-check after every wakeup: Condition.wait() may return without
            # the slot having been filled, and a single `if` would then hand
            # the caller a bogus value.
            while self._result is self._EMPTY:
                self._cond.wait()

            res = self._result
            self._result = self._EMPTY
            return res
41 |
42 |
43 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
44 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
45 |
46 |
47 | class SlavePipe(_SlavePipeBase):
48 | """Pipe for master-slave communication."""
49 |
50 | def run_slave(self, msg):
51 | self.queue.put((self.identifier, msg))
52 | ret = self.result.get()
53 | self.queue.put(True)
54 | return ret
55 |
56 |
class SyncMaster(object):
    """Rendezvous point between a master device and its slave devices.

    - During replication, every slave copy calls ``register_slave(id)`` and
      gets a ``SlavePipe`` for talking to the master.
    - During the forward pass, the master calls ``run_master``. This collects
      one message from every slave and passes all messages (the master's
      first) to the registered callback.
    - The callback returns one message per device. The master sends each
      slave its message, waits for every slave's acknowledgement, and
      returns its own message.
    """

    def __init__(self, master_callback):
        """
        Args:
            master_callback: a callback to be invoked after having collected messages from slave devices.
        """
        # Master and slaves are threads of the same process, so a thread
        # queue is the right primitive. A multiprocessing.Queue would pickle
        # every message through a pipe: tensors would be copied (losing their
        # autograd history, or failing to serialize), and its `empty()`
        # (used in register_slave) is documented as unreliable.
        import queue

        self._master_callback = master_callback
        self._queue = queue.Queue()
        self._registry = collections.OrderedDict()
        self._activated = False

    def __getstate__(self):
        # Only the callback is pickled; the queue and registry are rebuilt.
        return {'master_callback': self._master_callback}

    def __setstate__(self, state):
        self.__init__(state['master_callback'])

    def register_slave(self, identifier):
        """
        Register a slave device.

        The first registration after a ``run_master`` call clears the
        registry from the previous replication.

        Args:
            identifier: an identifier, usually the device id.
        Returns: a `SlavePipe` object which can be used to communicate with the master device.
        """
        if self._activated:
            assert self._queue.empty(), 'Queue is not clean before next initialization.'
            self._activated = False
            self._registry.clear()
        future = FutureResult()
        self._registry[identifier] = _MasterRegistry(future)
        return SlavePipe(identifier, self._queue, future)

    def run_master(self, master_msg):
        """
        Main entry for the master device in each forward pass.

        Messages are first collected from every device (including the master).
        The callback then computes the message to send back to each device
        (including the master).

        Args:
            master_msg: the message the master sends to itself. It is placed first
                when calling `master_callback`. See `_SynchronizedBatchNorm` for an example.
        Returns: the message to be sent back to the master device.
        """
        self._activated = True

        intermediates = [(0, master_msg)]
        for i in range(self.nr_slaves):
            intermediates.append(self._queue.get())

        results = self._master_callback(intermediates)
        assert results[0][0] == 0, 'The first result should belongs to the master.'

        for i, res in results:
            if i == 0:
                continue
            self._registry[i].result.put(res)

        # Wait for every slave's acknowledgement (SlavePipe.run_slave puts True).
        for i in range(self.nr_slaves):
            assert self._queue.get() is True

        return results[0][1]

    @property
    def nr_slaves(self):
        """Number of currently registered slave devices."""
        return len(self._registry)
131 |
--------------------------------------------------------------------------------
/models/include/deeplabv3plus_inc/modeling/sync_batchnorm/replicate.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : replicate.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import functools
12 |
13 | from torch.nn.parallel.data_parallel import DataParallel
14 |
15 | __all__ = [
16 | 'CallbackContext',
17 | 'execute_replication_callbacks',
18 | 'DataParallelWithCallback',
19 | 'patch_replication_callback'
20 | ]
21 |
22 |
class CallbackContext(object):
    """Empty attribute bag shared by all copies of one sub-module position."""


def execute_replication_callbacks(modules):
    """
    Run the ``__data_parallel_replicate__`` callback on every sub-module of every copy.

    The callback is invoked as ``__data_parallel_replicate__(ctx, copy_id)``.
    All copies are isomorphic, so sub-modules at the same position in
    ``module.modules()`` share one context object. Copies can exchange
    information through it. The master copy (``modules[0]``, copy id 0) is
    processed before any slave copy.
    """
    master_copy = modules[0]
    ctxs = [CallbackContext() for _ in master_copy.modules()]

    for copy_id, replica in enumerate(modules):
        for position, submodule in enumerate(replica.modules()):
            if not hasattr(submodule, '__data_parallel_replicate__'):
                continue
            submodule.__data_parallel_replicate__(ctxs[position], copy_id)
45 |
46 |
class DataParallelWithCallback(DataParallel):
    """
    ``DataParallel`` that runs replication callbacks after each replicate.

    After the standard ``replicate``, ``__data_parallel_replicate__(ctx, copy_id)``
    is invoked on every sub-module of every replica that defines it.

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
        # sync_bn.__data_parallel_replicate__ will be invoked.
    """

    def replicate(self, module, device_ids):
        replicas = super(DataParallelWithCallback, self).replicate(module, device_ids)
        execute_replication_callbacks(replicas)
        return replicas
63 |
64 |
def patch_replication_callback(data_parallel):
    """Monkey-patch an existing ``DataParallel`` instance so it runs replication callbacks.

    Useful when you have a customized ``DataParallel`` implementation and cannot
    switch to ``DataParallelWithCallback``. The instance's ``replicate`` is
    replaced by a wrapper that calls the original and then passes the replicas
    to ``execute_replication_callbacks``.

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
        > patch_replication_callback(sync_bn)
        # this is equivalent to
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])

    Raises:
        AssertionError: if ``data_parallel`` is not a ``DataParallel`` (unless run with ``-O``).
    """
    assert isinstance(data_parallel, DataParallel)

    original_replicate = data_parallel.replicate

    def replicate_with_callbacks(module, device_ids):
        replicas = original_replicate(module, device_ids)
        execute_replication_callbacks(replicas)
        return replicas

    # Copy the name/doc of the bound method being replaced onto the wrapper.
    data_parallel.replicate = functools.wraps(original_replicate)(replicate_with_callbacks)
--------------------------------------------------------------------------------
/models/networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import init
4 | from torch.optim import lr_scheduler
5 | import torch.nn.functional as F
6 |
7 | from models.include.deeplabv3plus_inc.modeling.sync_batchnorm.batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d
8 | from models.include.deeplabv3plus_inc.modeling.aspp import build_aspp
9 | from models.include.deeplabv3plus_inc.modeling.decoder import build_decoder
10 | from models.include.deeplabv3plus_inc.modeling.backbone import resnet
11 |
12 |
13 | ### help functions ###
def get_scheduler(optimizer, opt):
    """Build the learning-rate scheduler selected by ``opt.lr_policy``.

    Supported policies:
        'lambda': ``LambdaLR`` scaling the base lr by
                  ``opt.lr_gamma ** ((epoch + 1) // opt.lr_decay_epochs)``.
        'step':   ``StepLR`` with ``step_size=opt.lr_decay_iters`` and gamma 0.1.

    Args:
        optimizer: the ``torch.optim`` optimizer to schedule.
        opt: options namespace providing ``lr_policy`` and the fields used by
            the selected policy.

    Returns:
        A ``torch.optim.lr_scheduler`` instance.

    Raises:
        NotImplementedError: if ``opt.lr_policy`` is not one of the above.
    """
    if opt.lr_policy == 'lambda':
        def lambda_rule(epoch):
            return opt.lr_gamma ** ((epoch + 1) // opt.lr_decay_epochs)
        return lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    if opt.lr_policy == 'step':
        return lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
    # The exception used to be *returned* (with an unformatted message), so
    # callers silently got an exception object instead of a scheduler.
    raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
23 |
def init_net(net, gpu_ids=None):
    """Move ``net`` to the first requested GPU and wrap it in ``DataParallel``.

    Args:
        net: the ``nn.Module`` to prepare.
        gpu_ids: list of CUDA device indices. ``None`` or an empty list keeps
            ``net`` where it is and returns it unwrapped.

    Returns:
        ``torch.nn.DataParallel(net, gpu_ids)`` when GPUs are requested,
        otherwise ``net`` itself.

    Raises:
        RuntimeError: if GPUs are requested but CUDA is not available.
    """
    # `None` default instead of a shared mutable `[]`.
    if not gpu_ids:
        return net
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not torch.cuda.is_available():
        raise RuntimeError('GPU ids %s were requested but CUDA is not available' % (gpu_ids,))
    net.to(gpu_ids[0])
    return torch.nn.DataParallel(net, gpu_ids)
30 |
def define_GALNet(num_labels, gpu_ids=None, input_channels=1, use_gal=True):
    """Build a ``GALDeepLabV3Plus`` network and hand it to ``init_net`` for device placement.

    Args:
        num_labels: number of segmentation classes (``n_class``).
        gpu_ids: list of CUDA device indices; ``None``/empty means no GPU wrapping.
        input_channels: number of channels of the network input.
        use_gal: whether to insert the GAL module after the backbone.

    Returns:
        Whatever ``init_net`` returns for the new network.
    """
    net = GALDeepLabV3Plus(n_class=num_labels, input_channels=input_channels, use_gal=use_gal)
    # `None` default avoids a shared mutable list; always pass a list downstream.
    return init_net(net, [] if gpu_ids is None else gpu_ids)
34 |
35 |
36 | # Ref: https://github.com/jfzhang95/pytorch-deeplab-xception
def build_backbone(backbone, output_stride, BatchNorm, input_channels):
    """Return the feature-extraction backbone called ``backbone``.

    Only ``'resnet'`` (``resnet.ResNet50``) is supported.

    Args:
        backbone: backbone name.
        output_stride: output stride forwarded to the backbone.
        BatchNorm: normalisation layer class used inside the backbone.
        input_channels: number of input image channels (``num_ch``).

    Raises:
        NotImplementedError: for any other backbone name.
    """
    if backbone != 'resnet':
        raise NotImplementedError
    return resnet.ResNet50(output_stride, BatchNorm, num_ch=input_channels)
42 |
class GALDeepLabV3Plus(nn.Module):
    """DeepLabv3+ segmentation network with an optional GAL module.

    Forward pipeline: backbone -> optional ``GAL`` on the high-level features
    -> ASPP -> decoder (also fed the backbone's low-level features) ->
    bilinear upsampling to the spatial size of the input.

    Args:
        n_class: number of output classes.
        backbone: backbone name passed to ``build_backbone``.
        output_stride: output stride of backbone/ASPP (forced to 8 for 'drn').
        sync_bn: ``== True`` selects ``SynchronizedBatchNorm2d``, otherwise
            ``nn.BatchNorm2d``.
        freeze_bn: if true, call ``freeze_bn()`` at the end of construction.
        input_channels: channel count of the input image.
        use_gal: if true, insert a 2048-channel ``GAL`` after the backbone.
    """

    def __init__(self, n_class=2, backbone='resnet', output_stride=16, sync_bn=True, freeze_bn=False, input_channels=1, use_gal=True):
        super().__init__()

        if backbone == 'drn':
            output_stride = 8

        # Explicit equality (not truthiness) kept on purpose.
        norm_layer = SynchronizedBatchNorm2d if sync_bn == True else nn.BatchNorm2d  # noqa: E712

        self.backbone = build_backbone(backbone, output_stride, norm_layer, input_channels)
        self.aspp = build_aspp(backbone, output_stride, norm_layer)
        self.decoder = build_decoder(n_class, backbone, norm_layer)

        self.use_gal = bool(use_gal)
        if self.use_gal:
            print("Using GAL")
            self.gal = GAL(sync_bn=sync_bn, input_channels=2048)

        if freeze_bn:
            self.freeze_bn()

    def forward(self, input):
        """Return per-class scores upsampled to the input's spatial size."""
        image = input.float()
        features, low_level_features = self.backbone(image)
        if self.use_gal:
            features = self.gal(features)
        logits = self.decoder(self.aspp(features), low_level_features)
        return F.interpolate(logits, size=image.size()[2:], mode='bilinear', align_corners=True)

    def freeze_bn(self):
        """Put every 2-D batch-norm layer (synchronized or plain) into eval mode.

        NOTE(review): a later ``.train()`` call on this module switches these
        layers back to training mode (standard ``nn.Module.train`` behaviour).
        """
        for module in self.modules():
            if isinstance(module, (SynchronizedBatchNorm2d, nn.BatchNorm2d)):
                module.eval()
89 |
90 |
class GAL(nn.Module):
    """Graph attention layer over the 4-connected pixel grid of a feature map.

    Every spatial position is a vertex. Its edges are the element-wise
    products of the vertex with its neighbours at rows h-1/h+1 and columns
    w-1/w+1 (wrapping at the borders). Vertex and edge features are updated by
    Linear-BN-ReLU blocks and multiplied together. The product is concatenated
    with the input, and a 1x1 Conv-BN-ReLU maps it back to ``input_channels``,
    so the output has the input's (B, C, H, W) shape.

    Args:
        sync_bn: ``== True`` selects synchronized batch norm, otherwise plain BN.
        input_channels: channel count C of the input feature map.
    """

    def __init__(self, sync_bn=True, input_channels=2048):
        super().__init__()
        self.input_channels = input_channels
        if sync_bn == True:  # noqa: E712 -- explicit equality kept on purpose
            norm1d, norm2d = SynchronizedBatchNorm1d, SynchronizedBatchNorm2d
        else:
            norm1d, norm2d = nn.BatchNorm1d, nn.BatchNorm2d

        half = input_channels // 2

        def linear_bn_relu(in_features, out_features):
            # Sub-module indices 0/1/2 keep the state_dict keys unchanged.
            return nn.Sequential(
                nn.Linear(in_features, out_features),
                norm1d(out_features),
                nn.ReLU(inplace=True),
            )

        # Same construction order as before, so parameter initialisation
        # consumes the random number generator identically.
        self.edge_aggregation_func = linear_bn_relu(4, 1)
        self.vertex_update_func = linear_bn_relu(2 * input_channels, half)
        self.edge_update_func = linear_bn_relu(2 * input_channels, half)
        self.update_edge_reduce_func = linear_bn_relu(4, 1)
        self.final_aggregation_layer = nn.Sequential(
            nn.Conv2d(input_channels + half, input_channels, kernel_size=1, stride=1, padding=0, bias=False),
            norm2d(input_channels),
            nn.ReLU(inplace=True),
        )

        self._init_weight()

    def forward(self, input):
        B, C, H, W = input.size()
        vertex = input

        # Neighbour values from rows h-1 / h+1 and columns w-1 / w+1
        # (torch.roll wraps around), stacked on a new trailing axis of size 4.
        neighbours = torch.stack(
            (
                torch.roll(input, shifts=1, dims=2),
                torch.roll(input, shifts=-1, dims=2),
                torch.roll(input, shifts=1, dims=3),
                torch.roll(input, shifts=-1, dims=3),
            ),
            dim=-1,
        )
        edge = neighbours * input.unsqueeze(dim=-1)  # (B, C, H, W, 4)

        # Collapse the 4 edges of every (b, c, h, w) into a single value.
        aggregated_edge = self.edge_aggregation_func(edge.reshape(-1, 4)).reshape((B, C, H, W))

        # Vertex update: per pixel, a Linear over the [vertex, aggregated edge] channels.
        vertex_input = torch.cat((vertex, aggregated_edge), dim=1)
        vertex_rows = vertex_input.permute(0, 2, 3, 1).reshape((-1, 2 * self.input_channels))
        update_vertex = (
            self.vertex_update_func(vertex_rows)
            .reshape((B, H, W, self.input_channels // 2))
            .permute(0, 3, 1, 2)
        )

        # Edge update: per pixel and edge, a Linear over the [vertex, edge] channels,
        # then the 4 edges of every (b, c, h, w) are reduced to one value.
        edge_input = torch.cat((torch.stack((vertex,) * 4, dim=-1), edge), dim=1)
        edge_rows = edge_input.permute(0, 2, 3, 4, 1).reshape((-1, 2 * self.input_channels))
        update_edge = (
            self.edge_update_func(edge_rows)
            .reshape((B, H, W, 4, C // 2))
            .permute(0, 4, 1, 2, 3)
            .reshape((-1, 4))
        )
        update_edge_converted = self.update_edge_reduce_func(update_edge).reshape((B, C // 2, H, W))

        update_feature = update_vertex * update_edge_converted
        return self.final_aggregation_layer(torch.cat((input, update_feature), dim=1))

    def _init_weight(self):
        """Kaiming-normal init for conv/linear weights; weight 1 and bias 0 for batch norms."""
        norm_types = (SynchronizedBatchNorm1d, nn.BatchNorm1d, SynchronizedBatchNorm2d, nn.BatchNorm2d)
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                torch.nn.init.kaiming_normal_(module.weight)
            elif isinstance(module, norm_types):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
188 |
189 |
class SegmantationLoss(nn.Module):
    """Pixel-wise cross-entropy loss for semantic segmentation.

    Args:
        class_weights: optional per-class weight tensor, forwarded as
            ``nn.CrossEntropyLoss(weight=class_weights)``.
    """

    def __init__(self, class_weights=None):
        super().__init__()
        self.loss = nn.CrossEntropyLoss(weight=class_weights)

    def forward(self, output, target):
        """Return the cross-entropy between class scores ``output`` and labels ``target``."""
        # Implemented as ``forward`` rather than overriding ``__call__``: calling
        # ``criterion(output, target)`` is unchanged, but module hooks now run and
        # ``criterion.forward(...)`` no longer raises NotImplementedError.
        return self.loss(output, target)


# Correctly spelled alias; the original name is kept for backward compatibility.
SegmentationLoss = SegmantationLoss
196 |
--------------------------------------------------------------------------------
/options/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruirangerfan/GAL-DeepLabv3Plus/da613d0907ebf2908978a08f72b1a58e23caafb9/options/__init__.py
--------------------------------------------------------------------------------
/options/base_options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from util import util
4 | import torch
5 | import models
6 | import data
7 |
8 |
class BaseOptions:
    """Command-line options shared by training and testing.

    Subclasses extend ``initialize`` with phase-specific arguments and must set
    ``self.isTrain`` there, because ``gather_options`` and ``parse`` read it.
    """

    def __init__(self):
        self.initialized = False

    def initialize(self, parser):
        """Register the options common to every phase on ``parser`` and return it."""
        parser.add_argument('--batch_size', type=int, default=2, help='input batch size')
        parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
        parser.add_argument('--name', type=str, default='pothole', help='name of the experiment. It decides where to store samples and models')
        parser.add_argument('--input', type=str, default='tdisp', help='chooses input images')
        parser.add_argument('--dataset', type=str, default='pothole', help='chooses which dataset to load.')
        parser.add_argument('--model', type=str, default='galnet', help='chooses which model to use.')
        parser.add_argument('--gal', action='store_true', help='if true, use gal')
        parser.add_argument('--epoch', type=str, default='best', help='chooses which epoch to load')
        parser.add_argument('--num_threads', default=2, type=int, help='# threads for loading data')
        parser.add_argument('--checkpoints_dir', type=str, default='./runs', help='models and records are saved here')
        parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
        parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
        parser.add_argument('--seed', type=int, default=0, help='seed for random generators')
        self.initialized = True
        return parser

    def gather_options(self):
        """Build the complete parser and parse the command line.

        A first pass determines ``--model`` and ``--dataset``. The option
        setters returned by ``models.get_option_setter`` and
        ``data.get_option_setter`` then extend the parser before the final,
        strict ``parse_args``. The finished parser is kept in ``self.parser``
        for ``print_options``.
        """
        # Always build a fresh parser. Previously a parser was only created while
        # ``self.initialized`` was false, so a second call raised UnboundLocalError.
        parser = argparse.ArgumentParser(
            formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        parser = self.initialize(parser)

        # get the basic options
        opt, _ = parser.parse_known_args()

        # modify model-related parser options
        model_option_setter = models.get_option_setter(opt.model)
        parser = model_option_setter(parser, self.isTrain)
        opt, _ = parser.parse_known_args()  # parse again with the new defaults

        # modify dataset-related parser options
        dataset_option_setter = data.get_option_setter(opt.dataset)
        parser = dataset_option_setter(parser, self.isTrain)

        self.parser = parser

        return parser.parse_args()

    def print_options(self, opt):
        """Print every option (flagging non-default values) and save the listing to <checkpoints_dir>/<name>/opt.txt."""
        message = ''
        message += '----------------- Options ---------------\n'
        for k, v in sorted(vars(opt).items()):
            comment = ''
            default = self.parser.get_default(k)
            if v != default:
                comment = '\t[default: %s]' % str(default)
            message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
        message += '----------------- End -------------------'
        print(message)

        # save to the disk
        expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
        util.mkdirs(expr_dir)
        file_name = os.path.join(expr_dir, 'opt.txt')
        with open(file_name, 'wt') as opt_file:
            opt_file.write(message)
            opt_file.write('\n')

    @staticmethod
    def _parse_gpu_ids(gpu_ids):
        """Turn a comma-separated string such as ``'0,2'`` into ``[0, 2]``.

        Negative ids (``'-1'``) mean CPU and are dropped; blank entries (for
        example from a trailing comma) are ignored instead of crashing ``int``.
        """
        ids = []
        for token in gpu_ids.split(','):
            token = token.strip()
            if not token:
                continue
            gpu_id = int(token)
            if gpu_id >= 0:
                ids.append(gpu_id)
        return ids

    def parse(self):
        """Parse, print and save the options, and select the first requested GPU.

        Returns:
            The parsed namespace with ``isTrain`` added and ``gpu_ids``
            converted to a list of ints.
        """
        opt = self.gather_options()
        opt.isTrain = self.isTrain  # train or test

        self.print_options(opt)

        # set gpu ids
        opt.gpu_ids = self._parse_gpu_ids(opt.gpu_ids)
        if opt.gpu_ids:
            torch.cuda.set_device(opt.gpu_ids[0])

        self.opt = opt
        return self.opt
93 |
--------------------------------------------------------------------------------
/options/test_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 |
class TestOptions(BaseOptions):
    """Options for evaluation: the shared options plus the output directory and split."""

    def initialize(self, parser):
        """Add evaluation-only arguments on top of the shared ones and mark the run as testing."""
        parser = super().initialize(parser)
        self.isTrain = False
        parser.add_argument('--results_dir', type=str, default='./testresults/', help='saves results here.')
        parser.add_argument('--phase', type=str, default='test', help='train, val, test')
        return parser
11 |
--------------------------------------------------------------------------------
/options/train_options.py:
--------------------------------------------------------------------------------
1 | from typing_extensions import Required
2 | from .base_options import BaseOptions
3 |
4 |
class TrainOptions(BaseOptions):
    """Options for training: the shared options plus optimisation and LR-schedule settings."""

    def initialize(self, parser):
        """Add training-only arguments on top of the shared ones and mark the run as training."""
        parser = BaseOptions.initialize(self, parser)
        parser.add_argument('--print_freq', type=int, default=10, help='frequency of showing training results on console')
        parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
        parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count')
        parser.add_argument('--phase', type=str, default='train', help='train, val, test')
        parser.add_argument('--nepoch', type=int, default=1000, help='maximum epochs')
        parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
        parser.add_argument('--lr', type=float, default=0.001, help='initial learning rate for optimizer')
        parser.add_argument('--momentum', type=float, default=0.9, help='momentum factor for SGD')
        # The help texts below were corrected: weight_decay was described as a
        # momentum factor, and get_scheduler in models/networks.py implements
        # only the 'lambda' and 'step' policies.
        parser.add_argument('--weight_decay', type=float, default=0.0005, help='weight decay for optimizer')
        parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step')
        parser.add_argument('--lr_decay_iters', type=int, default=5000000, help="'step' policy: multiply the lr by 0.1 every lr_decay_iters scheduler steps")
        parser.add_argument('--lr_decay_epochs', type=int, default=25, help="'lambda' policy: multiply the lr by lr_gamma every lr_decay_epochs scheduler steps")
        parser.add_argument('--lr_gamma', type=float, default=0.9, help="decay factor used by the 'lambda' policy")
        self.isTrain = True
        return parser
23 |
--------------------------------------------------------------------------------
/scripts/test_gal.sh:
--------------------------------------------------------------------------------
1 | python3 test.py --dataset pothole --model galnet --input tdisp --name tdisp_gal --gal --epoch best
--------------------------------------------------------------------------------
/scripts/train_gal.sh:
--------------------------------------------------------------------------------
1 | python3 train.py --dataset pothole --model galnet --input tdisp --name tdisp_gal --gal
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | from options.test_options import TestOptions
3 | from data import CreateDataLoader
4 | from models import create_model
5 | from util.util import confusion_matrix, getScores, save_images
6 | import torch
7 | import numpy as np
8 |
9 |
if __name__ == '__main__':
    opt = TestOptions().parse()
    # Evaluate one image at a time, in dataset order.
    opt.num_threads = 1
    opt.batch_size = 1
    opt.serial_batches = True  # no shuffle
    opt.isTrain = False

    save_dir = os.path.join(opt.results_dir, opt.name, opt.phase + '_' + opt.epoch)
    os.makedirs(save_dir, exist_ok=True)  # no exists-then-create race

    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    model = create_model(opt, dataset.dataset)
    model.setup(opt)
    model.eval()

    num_labels = dataset.dataset.num_labels
    test_loss_iter = []
    epoch_iter = 0
    conf_mat = np.zeros((num_labels, num_labels), dtype=np.float32)
    with torch.no_grad():
        for data in dataset:
            model.set_input(data)
            model.forward()
            model.get_loss()
            epoch_iter += opt.batch_size
            gt = model.label.cpu().int().numpy()
            _, pred = torch.max(model.output.data.cpu(), 1)
            pred = pred.int().numpy()
            save_images(save_dir, model.get_current_visuals(), model.get_image_names())
            conf_mat += confusion_matrix(gt, pred, num_labels)

            test_loss_iter.append(model.loss_segmentation)
            print('Epoch {0:}, iters: {1:}/{2:}, loss: {3:.3f} '.format(opt.epoch,
                                                                        epoch_iter,
                                                                        len(dataset) * opt.batch_size,
                                                                        test_loss_iter[-1]), end='\r')
    # Terminate the '\r' progress line so the summary does not overwrite it.
    print()

    # torch.stack([]) raises on an empty list, so guard the no-sample case.
    if test_loss_iter:
        avg_test_loss = torch.mean(torch.stack(test_loss_iter))
        print ('Epoch {0:} test loss: {1:.3f} '.format(opt.epoch, avg_test_loss))
    else:
        print('Epoch {0:} test loss: no test samples'.format(opt.epoch))
    globalacc, pre, recall, F_score, iou = getScores(conf_mat)
    print ('Epoch {0:} glob acc : {1:.3f}, pre : {2:.3f}, recall : {3:.3f}, F_score : {4:.3f}, IoU : {5:.3f}'.format(opt.epoch, globalacc, pre, recall, F_score, iou))
52 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import time
2 | from options.train_options import TrainOptions
3 | from data import CreateDataLoader
4 | from models import create_model
5 | from util.util import confusion_matrix, getScores, tensor2labelim, tensor2im, print_current_losses
6 | import numpy as np
7 | import random
8 | import torch
9 | import os
10 | from tensorboardX import SummaryWriter
11 |
12 |
if __name__ == '__main__':
    train_opt = TrainOptions().parse()

    # Seed every RNG in use for reproducible runs.
    np.random.seed(train_opt.seed)
    random.seed(train_opt.seed)
    torch.manual_seed(train_opt.seed)
    torch.cuda.manual_seed(train_opt.seed)

    train_data_loader = CreateDataLoader(train_opt)
    train_dataset = train_data_loader.load_data()
    train_dataset_size = len(train_data_loader)
    print('#training images = %d' % train_dataset_size)

    # Validation: same options, evaluated one image at a time in order.
    valid_opt = TrainOptions().parse()
    valid_opt.phase = 'val'
    valid_opt.batch_size = 1
    valid_opt.num_threads = 1
    valid_opt.serial_batches = True
    valid_opt.isTrain = False
    valid_data_loader = CreateDataLoader(valid_opt)
    valid_dataset = valid_data_loader.load_data()
    valid_dataset_size = len(valid_data_loader)
    print('#validation images = %d' % valid_dataset_size)

    writer = SummaryWriter(os.path.join(train_opt.checkpoints_dir, train_opt.name))

    model = create_model(train_opt, train_dataset.dataset)
    model.setup(train_opt)

    # The palette used to colour label maps never changes: load it once, not every epoch.
    palet_file = 'datasets/palette.txt'
    impalette = list(np.genfromtxt(palet_file, dtype=np.uint8).reshape(3*256))

    num_labels = valid_dataset.dataset.num_labels
    total_steps = 0
    tfcount = 0
    iou_max = 0
    for epoch in range(train_opt.epoch_count, train_opt.nepoch + 1):
        ### Training on the training set ###
        model.train()
        epoch_start_time = time.time()
        iter_data_time = time.time()
        epoch_iter = 0
        train_loss_iter = []
        for data in train_dataset:
            iter_start_time = time.time()
            if total_steps % train_opt.print_freq == 0:
                t_data = iter_start_time - iter_data_time
            total_steps += train_opt.batch_size
            epoch_iter += train_opt.batch_size
            model.set_input(data)
            model.optimize_parameters()

            if total_steps % train_opt.print_freq == 0:
                tfcount = tfcount + 1
                losses = model.get_current_losses()
                train_loss_iter.append(losses["segmentation"])
                t = (time.time() - iter_start_time) / train_opt.batch_size
                print_current_losses(epoch, epoch_iter, losses, t, t_data)
                # There are several whole_loss values shown in tensorboard in one epoch,
                # to help better see the optimization phase
                writer.add_scalar('train/whole_loss', losses["segmentation"], tfcount)

            iter_data_time = time.time()

        # One average training loss value in tensorboard per epoch. Skipped when
        # no loss was sampled: np.mean([]) would log NaN and emit a RuntimeWarning.
        if train_loss_iter:
            writer.add_scalar('train/mean_loss', np.mean(train_loss_iter), epoch)

        tempDict = model.get_current_visuals()
        rgb = tensor2im(tempDict['rgb_image'])
        tdisp = tensor2im(tempDict['tdisp_image'])
        label = tensor2labelim(tempDict['label'], impalette)
        output = tensor2labelim(tempDict['output'], impalette)
        image_numpy = np.concatenate((rgb, tdisp, label, output), axis=1)
        image_numpy = image_numpy.astype(np.float32) / 255
        writer.add_image('Epoch' + str(epoch), image_numpy, dataformats='HWC')  # show training images in tensorboard

        print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, train_opt.nepoch, time.time() - epoch_start_time))
        model.update_learning_rate()

        ### Evaluation on the validation set ###
        model.eval()
        valid_loss_iter = []
        epoch_iter = 0
        conf_mat = np.zeros((num_labels, num_labels), dtype=np.float32)
        with torch.no_grad():
            for data in valid_dataset:
                model.set_input(data)
                model.forward()
                model.get_loss()
                epoch_iter += valid_opt.batch_size
                gt = model.label.cpu().int().numpy()
                _, pred = torch.max(model.output.data.cpu(), 1)
                pred = pred.int().numpy()

                conf_mat += confusion_matrix(gt, pred, num_labels)
                valid_loss_iter.append(model.loss_segmentation)
                print('valid epoch {0:}, iters: {1:}/{2:} '.format(epoch, epoch_iter, len(valid_dataset) * valid_opt.batch_size), end='\r')

        globalacc, pre, recall, F_score, iou = getScores(conf_mat)

        # Record performance on the validation set (torch.stack([]) would raise).
        if valid_loss_iter:
            writer.add_scalar('valid/loss', torch.mean(torch.stack(valid_loss_iter)), epoch)
        writer.add_scalar('valid/global_acc', globalacc, epoch)
        writer.add_scalar('valid/pre', pre, epoch)
        writer.add_scalar('valid/recall', recall, epoch)
        writer.add_scalar('valid/F_score', F_score, epoch)
        writer.add_scalar('valid/iou', iou, epoch)

        # Save the best model according to the validation IoU, and record the epoch number in tensorboard
        if iou > iou_max:
            print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))
            model.save_networks('best')
            iou_max = iou
            writer.add_text('best model', str(epoch))

    # Flush pending tensorboard events and release the file handle.
    writer.close()
127 |
--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruirangerfan/GAL-DeepLabv3Plus/da613d0907ebf2908978a08f72b1a58e23caafb9/util/__init__.py
--------------------------------------------------------------------------------
/util/util.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import torch
3 | import numpy as np
4 | from PIL import Image
5 | import os
6 | import cv2
7 |
8 |
def save_images(save_dir, visuals, image_name):
    """Write the palette-coloured ``'output'`` visual into ``save_dir``.

    Args:
        save_dir: destination directory.
        visuals: mapping of visual name -> tensor; only the ``'output'`` entry is written.
        image_name: sequence of file names; the first one names the saved file.
    """
    file_name = image_name[0]
    impalette = list(np.genfromtxt('datasets/palette.txt', dtype=np.uint8).reshape(3*256))

    for name, im_data in visuals.items():
        if name != 'output':
            continue
        colour_image = tensor2labelim(im_data, impalette)
        # cv2 writes BGR, tensor2labelim produces RGB.
        cv2.imwrite(os.path.join(save_dir, file_name), cv2.cvtColor(colour_image, cv2.COLOR_RGB2BGR))
19 |
def tensor2im(input_image, imtype=np.uint8):
    """Convert the first image of a batched tensor into an (H, W, C) numpy image.

    Values are multiplied by 255, and single-channel images are replicated to
    three channels. For integer ``imtype`` the scaled values are clipped to the
    dtype's range before casting. Non-tensor inputs are returned unchanged.

    Args:
        input_image: tensor indexed as (batch, channels, H, W); only item 0 is used.
        imtype: numpy dtype of the result.

    Returns:
        ``numpy.ndarray`` of shape (H, W, C) and dtype ``imtype``.
    """
    if not isinstance(input_image, torch.Tensor):
        return input_image
    image_numpy = input_image.data[0].cpu().float().numpy()
    if image_numpy.shape[0] == 1:
        image_numpy = np.tile(image_numpy, (3, 1, 1))
    image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0
    # Saturate instead of wrapping: casting out-of-range floats to an integer
    # dtype is undefined in numpy (e.g. 1.2 * 255 does not become 255).
    if np.issubdtype(np.dtype(imtype), np.integer):
        limits = np.iinfo(imtype)
        image_numpy = np.clip(image_numpy, limits.min, limits.max)
    return image_numpy.astype(imtype)
31 |
def tensor2labelim(label_tensor, impalette, imtype=np.uint8):
    """Colour a label map with a palette and return it as an RGB numpy image.

    Args:
        label_tensor: a 4-D score tensor (reduced by argmax over dim 1) or a
            batch of label maps; only batch item 0 is converted.
        impalette: flat list of palette values (256 RGB triples), as accepted
            by PIL ``Image.putpalette``.
        imtype: dtype of the returned array.

    Returns:
        ``numpy.ndarray`` holding the RGB rendering of the labels.
    """
    if label_tensor.ndim == 4:
        _, label_tensor = torch.max(label_tensor.data.cpu(), 1)

    labels = label_tensor[0].cpu().float().detach().numpy().astype(np.uint8)
    paletted = Image.fromarray(labels).convert("P")
    paletted.putpalette(impalette)
    return np.array(paletted.convert("RGB")).astype(imtype)
44 |
def print_current_losses(epoch, i, losses, t, t_data):
    """Print one progress line: epoch, iteration, timings and each loss to 3 decimals."""
    header = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, i, t, t_data)
    loss_text = ''.join('%s: %.3f ' % (name, value) for name, value in losses.items())
    print(header + loss_text)
50 |
51 |
def mkdirs(paths):
    """Create a single directory, or every directory in a list/tuple, via ``mkdir``.

    Args:
        paths: one path, or a list/tuple of paths.
    """
    # A str is never a list/tuple, so no separate str check is needed;
    # tuples are now accepted as well as lists.
    if isinstance(paths, (list, tuple)):
        for path in paths:
            mkdir(path)
    else:
        mkdir(paths)
58 |
def mkdir(path):
    """Create ``path`` and any missing parents; do nothing if the directory already exists."""
    # exist_ok avoids the race between os.path.exists and os.makedirs.
    os.makedirs(path, exist_ok=True)
62 |
63 |
def confusion_matrix(x, y, n, ignore_label=None, mask=None):
    """Count (ground truth, prediction) pairs into an ``n x n`` confusion matrix.

    Entry ``[i, j]`` is the number of elements with ground truth ``i`` that were
    predicted as ``j``. An element is skipped when:
      * either label is outside ``[0, n)``;
      * ``x == ignore_label``;
      * ``mask`` is false.

    Args:
        x: ground-truth labels (integer array-like).
        y: predicted labels, same shape as ``x``.
        n: number of classes.
        ignore_label: optional ground-truth value to exclude.
        mask: optional boolean array selecting the elements to count.

    Returns:
        ``numpy.ndarray`` of shape (n, n) holding integer counts.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    if mask is None:
        valid = np.ones(x.shape, dtype=bool)
    else:
        # `np.bool` is missing in NumPy 1.24-1.26; builtin bool works everywhere.
        valid = np.asarray(mask).astype(bool)
    # Both labels must be in range. Previously an out-of-range ground truth
    # (e.g. a 255 "void" label) pushed bincount past n**2 and broke the reshape.
    valid = valid & (x >= 0) & (x < n) & (y >= 0) & (y < n)
    if ignore_label is not None:
        valid = valid & (x != ignore_label)
    pair_index = n * x[valid].astype(int) + y[valid].astype(int)
    return np.bincount(pair_index, minlength=n**2).reshape(n, n)
69 |
def getScores(conf_matrix, positive_class=1):
    """Compute segmentation metrics from a confusion matrix.

    Rows are ground truth and columns are predictions, as produced by
    ``confusion_matrix``. Precision, recall, F-score and IoU are reported for
    ``positive_class`` only; global accuracy covers all classes.

    Args:
        conf_matrix: square numpy array of counts.
        positive_class: index of the class the per-class metrics refer to
            (default 1, the previous hard-coded value).

    Returns:
        ``(globalacc, pre, recall, F_score, iou)``, or all zeros when the matrix
        is empty. Undefined ratios (0/0) come out as NaN without warnings.
    """
    if conf_matrix.sum() == 0:
        return 0, 0, 0, 0, 0
    diag = np.diag(conf_matrix)
    with np.errstate(divide='ignore', invalid='ignore'):
        globalacc = diag.sum() / conf_matrix.sum().astype(np.float32)
        classpre = diag / conf_matrix.sum(0).astype(np.float32)
        classrecall = diag / conf_matrix.sum(1).astype(np.float32)
        IU = diag / (conf_matrix.sum(1) + conf_matrix.sum(0) - diag).astype(np.float32)
        pre = classpre[positive_class]
        recall = classrecall[positive_class]
        iou = IU[positive_class]
        # Computed inside errstate too: pre == recall == 0 is a 0/0 that used
        # to emit a RuntimeWarning.
        F_score = 2 * (recall * pre) / (recall + pre)
    return globalacc, pre, recall, F_score, iou
83 |
--------------------------------------------------------------------------------