├── .gitignore ├── LICENSE ├── README.md ├── data ├── __init__.py ├── base_data_loader.py ├── base_dataset.py └── pothole_dataset.py ├── datasets └── palette.txt ├── doc ├── GAL-DeepLabv3+.png └── GAL.png ├── models ├── __init__.py ├── base_model.py ├── galnet_model.py ├── include │ ├── __init__.py │ └── deeplabv3plus_inc │ │ ├── __init__.py │ │ └── modeling │ │ ├── __init__.py │ │ ├── aspp.py │ │ ├── backbone │ │ ├── __init__.py │ │ └── resnet.py │ │ ├── decoder.py │ │ └── sync_batchnorm │ │ ├── __init__.py │ │ ├── batchnorm.py │ │ ├── comm.py │ │ └── replicate.py └── networks.py ├── options ├── __init__.py ├── base_options.py ├── test_options.py └── train_options.py ├── scripts ├── test_gal.sh └── train_gal.sh ├── test.py ├── train.py └── util ├── __init__.py └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Rui (Ranger) Fan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GAL-DeepLabv3Plus 2 | 3 | ## Introduction 4 | 5 | This is the official PyTorch implementation of **[Graph Attention Layer Evolves Semantic Segmentation for Road Pothole Detection: A Benchmark and Algorithms](https://ieeexplore.ieee.org/document/9547682)**, published in IEEE T-IP in 2021. 6 | 7 | In this repository, we provide the training and testing setups for the [pothole dataset](https://drive.google.com/file/d/1ofp-44LnYTDByOuVMOc2hBrUCUjncg3k/view?usp=sharing) ([paper](https://ieeexplore.ieee.org/abstract/document/8809907)). We have tested our code with Python 3.8.10, CUDA 11.1, and PyTorch 1.10.1. 8 | 9 |

10 | 11 |

12 | 13 |

14 | 15 |

16 | 17 | ## Setup 18 | 19 | Please set up the pothole dataset and the pretrained weights according to the following folder structure: 20 | 21 | ``` 22 | GAL-DeepLabv3plus 23 | |-- data 24 | |-- datasets 25 | | |-- pothole 26 | |-- models 27 | |-- options 28 | |-- runs 29 | | |-- tdisp_gal 30 | ... 31 | ``` 32 | 33 | The pothole dataset `datasets/pothole` can be downloaded from [here](https://drive.google.com/file/d/1ofp-44LnYTDByOuVMOc2hBrUCUjncg3k/view?usp=sharing), and the pretrained weights `runs/tdisp_gal` for our GAL-DeepLabv3+ can be downloaded from [here](https://drive.google.com/file/d/1wmgPUymOOPUWovwIwLdIg4hf0jsWyYja/view?usp=sharing). 34 | 35 | ## Usage 36 | 37 | ### Testing on the Pothole Dataset 38 | 39 | For testing, please first set up the `runs/tdisp_gal` and the `datasets/pothole` folders as mentioned above. Then, run the following script: 40 | 41 | ``` 42 | bash ./scripts/test_gal.sh 43 | ``` 44 | 45 | to test GAL-DeepLabv3+ with the transformed disparity images. The prediction results are stored in `testresults`. 46 | 47 | ### Training on the Pothole Dataset 48 | 49 | For training, please first set up the `datasets/pothole` folder as mentioned above. Then, run the following script: 50 | 51 | ``` 52 | bash ./scripts/train_gal.sh 53 | ``` 54 | 55 | to train GAL-DeepLabv3+ with the transformed disparity images. The weights and the TensorBoard logs containing the loss curves as well as the performance on the validation set will be saved in `runs`.
56 | 57 | ## Citation 58 | 59 | If you use this code for your research, please cite our paper: 60 | 61 | ``` 62 | @article{fan2021graph, 63 | title = {Graph Attention Layer Evolves Semantic Segmentation for Road Pothole Detection: A Benchmark and Algorithms}, 64 | author = {Fan, Rui and Wang, Hengli and Wang, Yuan and Liu, Ming and Pitas, Ioannis}, 65 | journal = {IEEE Transactions on Image Processing}, 66 | volume = {30}, 67 | number = {}, 68 | pages = {8144-8154}, 69 | year = {2021}, 70 | publisher = {IEEE}, 71 | doi = {10.1109/TIP.2021.3112316} 72 | } 73 | ``` 74 | If you use the pothole dataset for your research, please cite our papers: 75 | 76 | ``` 77 | @article{fan2019pothole, 78 | title={Pothole detection based on disparity transformation and road surface modeling}, 79 | author={Fan, Rui and Ozgunalp, Umar and Hosking, Brett and Liu, Ming and Pitas, Ioannis}, 80 | journal={IEEE Transactions on Image Processing}, 81 | volume={29}, 82 | pages={897--908}, 83 | year={2019}, 84 | publisher={IEEE} 85 | } 86 | @article{fan2019road, 87 | title={Road damage detection based on unsupervised disparity map segmentation}, 88 | author={Fan, Rui and Liu, Ming}, 89 | journal={IEEE Transactions on Intelligent Transportation Systems}, 90 | volume={21}, 91 | number={11}, 92 | pages={4906--4911}, 93 | year={2019}, 94 | publisher={IEEE} 95 | } 96 | @article{fan2018road, 97 | title={Road surface 3D reconstruction based on dense subpixel disparity map estimation}, 98 | author={Fan, Rui and Ai, Xiao and Dahnoun, Naim}, 99 | journal={IEEE Transactions on Image Processing}, 100 | volume={27}, 101 | number={6}, 102 | pages={3025--3035}, 103 | year={2018}, 104 | publisher={IEEE} 105 | } 106 | ``` 107 | 108 | 109 | ## Acknowledgement 110 | 111 | Our code is inspired by [pytorch-CycleGAN-and-pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix), [pytorch_segmentation](https://github.com/yassouali/pytorch_segmentation), [pytorch-deeplab-xception 112 | 
# README.md (continued): ](https://github.com/jfzhang95/pytorch-deeplab-xception), and [RTFNet](https://github.com/yuxiangsun/RTFNet). 113 |
# --------------------------------------------------------------------------------
# /data/__init__.py
# --------------------------------------------------------------------------------
import importlib
import sys
from functools import partial

import numpy
import torch.utils.data

from data.base_data_loader import BaseDataLoader
from data.base_dataset import BaseDataset


def find_dataset_using_name(dataset_name):
    """Return the dataset class defined in ``data/<dataset_name>_dataset.py``.

    Given the option ``--dataset [datasetname]``, the module
    ``data.<datasetname>_dataset`` is imported and searched for a subclass of
    ``BaseDataset`` whose name equals ``<datasetname>dataset`` (underscores
    removed, compared case-insensitively). If several attributes match, the
    last one found wins.

    Exits the process with status 1 if no matching class exists.
    """
    dataset_filename = "data." + dataset_name + "_dataset"
    datasetlib = importlib.import_module(dataset_filename)

    dataset = None
    target_dataset_name = dataset_name.replace('_', '') + 'dataset'
    for name, cls in datasetlib.__dict__.items():
        # issubclass() raises TypeError for non-class attributes, so make
        # sure a name match is actually a class before testing inheritance.
        if (name.lower() == target_dataset_name.lower()
                and isinstance(cls, type)
                and issubclass(cls, BaseDataset)):
            dataset = cls

    if dataset is None:
        print("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
        # Non-zero status so that calling shell scripts notice the failure.
        sys.exit(1)

    return dataset


def get_option_setter(dataset_name):
    """Return the ``modify_commandline_options`` hook of the named dataset class."""
    dataset_class = find_dataset_using_name(dataset_name)
    return dataset_class.modify_commandline_options


def create_dataset(opt):
    """Instantiate the dataset selected by ``opt.dataset`` and initialize it with ``opt``."""
    dataset = find_dataset_using_name(opt.dataset)
    instance = dataset()
    instance.initialize(opt)
    print("dataset [%s] was created" % (instance.name()))
    return instance


def CreateDataLoader(opt):
    """Build and initialize a ``CustomDatasetDataLoader`` from ``opt``."""
    data_loader = CustomDatasetDataLoader()
    data_loader.initialize(opt)
    return data_loader


def _seed_worker(base_seed, worker_id):
    """Seed NumPy's global RNG in a DataLoader worker with ``base_seed + worker_id``.

    A module-level function bound with ``functools.partial`` can be pickled,
    unlike a lambda, so this also works when workers are started with the
    ``spawn`` method (the default on Windows and macOS).
    """
    numpy.random.seed(base_seed + worker_id)


class CustomDatasetDataLoader(BaseDataLoader):
    """Wrap the selected dataset in a ``torch.utils.data.DataLoader``.

    The number of loader worker processes is taken from ``opt.num_threads``.
    """

    def name(self):
        return 'CustomDatasetDataLoader'

    def initialize(self, opt):
        BaseDataLoader.initialize(self, opt)
        self.dataset = create_dataset(opt)
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batch_size,
            # --serial_batches turns shuffling off.
            shuffle=not opt.serial_batches,
            num_workers=int(opt.num_threads),
            # An incomplete final batch is discarded.
            drop_last=True,
            worker_init_fn=partial(_seed_worker, opt.seed))

    def load_data(self):
        return self

    def __len__(self):
        """Number of samples in the dataset (not the number of batches)."""
        return len(self.dataset)

    def __iter__(self):
        """Yield batches from the underlying DataLoader."""
        yield from self.dataloader


# --------------------------------------------------------------------------------
# /data/base_data_loader.py
# --------------------------------------------------------------------------------
class BaseDataLoader():
    """Minimal base class for data-loader wrappers."""

    def __init__(self):
        pass

    def initialize(self, opt):
        """Keep a reference to the parsed options."""
        self.opt = opt

    def load_data(self):
        """Return the loadable data; the base class provides none."""
        return None


# --------------------------------------------------------------------------------
# /data/base_dataset.py
# --------------------------------------------------------------------------------
import torch.utils.data as data


class BaseDataset(data.Dataset):
    """Base class for datasets looked up by ``data.find_dataset_using_name``."""

    def __init__(self):
        super(BaseDataset, self).__init__()

    def name(self):
        return 'BaseDataset'

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Hook for adding dataset-specific options; the default adds none."""
        return parser

    def initialize(self, opt):
        """Hook for configuring the dataset from parsed options; no-op by default."""
        pass

    def __len__(self):
        return 0


# --------------------------------------------------------------------------------
# /data/pothole_dataset.py
# --------------------------------------------------------------------------------
import os.path
import torchvision.transforms as transforms
import torch
import cv2
import numpy as np
from data.base_dataset import BaseDataset


class potholedataset(BaseDataset):
    """Dataloader for the pothole dataset.

    Each sample ``NN.png`` is read from ``./datasets/pothole/{rgb,tdisp,label}/``.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        return parser

    def initialize(self, opt):
        self.opt = opt
        self.batch_size = opt.batch_size
        self.num_labels = 2

        # Images 01-42 are used for training; every other phase ("val",
        # "test", ...) uses images 43-55.
        if opt.phase == "train":
            self.image_list = np.arange(1, 43)
        else:
            self.image_list = np.arange(43, 56)

    @staticmethod
    def _imread(path, flags=cv2.IMREAD_COLOR):
        """Read an image with OpenCV, raising instead of silently returning None."""
        image = cv2.imread(path, flags)
        if image is None:
            raise FileNotFoundError("could not read image: %s" % path)
        return image

    def __getitem__(self, index):
        base_dir = "./datasets/pothole"
        name = str(self.image_list[index]).zfill(2) + ".png"

        rgb_image = cv2.cvtColor(self._imread(os.path.join(base_dir, 'rgb', name)), cv2.COLOR_BGR2RGB)
        tdisp_image = self._imread(os.path.join(base_dir, 'tdisp', name), cv2.IMREAD_ANYDEPTH)
        label_image = cv2.cvtColor(self._imread(os.path.join(base_dir, 'label', name)), cv2.COLOR_BGR2RGB)

        # A pixel is labelled 1 (pothole) wherever the label image's red
        # channel (index 0 after BGR->RGB) is non-zero, 0 elsewhere.
        label = np.zeros((label_image.shape[0], label_image.shape[1]), dtype=np.uint8)
        label[label_image[:, :, 0] > 0] = 1

        to_tensor = transforms.ToTensor()
        rgb_image = to_tensor(rgb_image.astype(np.float32) / 255)
        # NOTE(review): the 65535 scale presumably assumes 16-bit tdisp PNGs — confirm.
        tdisp_image = to_tensor(tdisp_image.astype(np.float32) / 65535)
        label = torch.from_numpy(label).type(torch.LongTensor)

        # return a dictionary containing useful information
        # input rgb images, tdisp images, and labels for training;
        # 'path': image name for saving predictions
        return {'rgb_image': rgb_image, 'tdisp_image': tdisp_image, 'label': label,
                'path': name}

    def __len__(self):
        return len(self.image_list)

    def name(self):
        return 'pothole'


# --------------------------------------------------------------------------------
# /datasets/palette.txt
# --------------------------------------------------------------------------------
# 1 | 0 0 0 2 | 128 0 0 3 | 0 128 0 4 | 128 128 0 5 | 0 0 128 6 | 128 0 128 7 | 0 128 128 8 | 128 128 128 9 | 64 0 0 10 | 192 0 0 11 | 64 128 0 12 | 192 128 0 13 | 64 0 128 14 | 192 0 128 15 | 64 128 128 16 | 192 128 128 17 | 0 64 0 18 | 128 64 0 19 | 0 192 0 20 | 128 192 0 21 | 0 64 128 22 | 128 64 128 23 | 0 192 128 24 | 128 192 128 25 | 64 64 0 26 | 192 64 0 27 | 64 192 0 28 | 192 192 0 29 | 64 64 128 30 | 192 64 128 31 | 64 192 128 32 | 192 192 128 33 | 0 0 64 34 | 128 0 64 35 | 0 128 64 36 | 128 128 64 37 | 0 0 192 38 | 128 0 192 39 | 0 128 192 40 | 128 128 192 41 | 64 0 64 42 | 192 0 64 43 | 64 128 64 44 | 192 128 64 45 | 64 0 192 46 | 192 0 192 47 | 64 128 192 48 | 192 128 192 49 | 0 64 64 50 | 128 64 64 51 | 0 192 64 52 | 128 192 64 53 | 0 64 192 54 | 128 64 192 55 | 0 192 192 56 | 128 192 192 57 | 64 64 64 58 | 192 64 64 59 | 64 192 64 60 | 192 192 64 61 | 64 64 192 62 | 192 64 192 63 | 64 192 192 64 | 192 192 192 65 | 32 0 0 66 | 160 0 0 67 | 32 128 0 68 | 160 128 0 69 | 32 0 128 70 | 160 0 128 71 | 32 128 128 72 | 160 128 128 73 | 96 0 0 74 | 224 0 0 75 | 96 128 0 76 | 224 128 0 77 | 96 0 128 78 | 224 0 128 79 | 96 128 128 80 | 224 128 128 81 | 32 64 0 82 | 160 64 0 83 | 32 192 0 84 | 160 192 0 85 | 32 64 128 86 | 160 64 128 87 | 32 192 128 88 | 160 192 128 89 | 96 64 0 90 | 224 64 0 91 | 96 192 0 92 | 224 192 0 93 | 96 64 128
94 | 224 64 128 95 | 96 192 128 96 | 224 192 128 97 | 32 0 64 98 | 160 0 64 99 | 32 128 64 100 | 160 128 64 101 | 32 0 192 102 | 160 0 192 103 | 32 128 192 104 | 160 128 192 105 | 96 0 64 106 | 224 0 64 107 | 96 128 64 108 | 224 128 64 109 | 96 0 192 110 | 224 0 192 111 | 96 128 192 112 | 224 128 192 113 | 32 64 64 114 | 160 64 64 115 | 32 192 64 116 | 160 192 64 117 | 32 64 192 118 | 160 64 192 119 | 32 192 192 120 | 160 192 192 121 | 96 64 64 122 | 224 64 64 123 | 96 192 64 124 | 224 192 64 125 | 96 64 192 126 | 224 64 192 127 | 96 192 192 128 | 224 192 192 129 | 0 32 0 130 | 128 32 0 131 | 0 160 0 132 | 128 160 0 133 | 0 32 128 134 | 128 32 128 135 | 0 160 128 136 | 128 160 128 137 | 64 32 0 138 | 192 32 0 139 | 64 160 0 140 | 192 160 0 141 | 64 32 128 142 | 192 32 128 143 | 64 160 128 144 | 192 160 128 145 | 0 96 0 146 | 128 96 0 147 | 0 224 0 148 | 128 224 0 149 | 0 96 128 150 | 128 96 128 151 | 0 224 128 152 | 128 224 128 153 | 64 96 0 154 | 192 96 0 155 | 64 224 0 156 | 192 224 0 157 | 64 96 128 158 | 192 96 128 159 | 64 224 128 160 | 192 224 128 161 | 0 32 64 162 | 128 32 64 163 | 0 160 64 164 | 128 160 64 165 | 0 32 192 166 | 128 32 192 167 | 0 160 192 168 | 128 160 192 169 | 64 32 64 170 | 192 32 64 171 | 64 160 64 172 | 192 160 64 173 | 64 32 192 174 | 192 32 192 175 | 64 160 192 176 | 192 160 192 177 | 0 96 64 178 | 128 96 64 179 | 0 224 64 180 | 128 224 64 181 | 0 96 192 182 | 128 96 192 183 | 0 224 192 184 | 128 224 192 185 | 64 96 64 186 | 192 96 64 187 | 64 224 64 188 | 192 224 64 189 | 64 96 192 190 | 192 96 192 191 | 64 224 192 192 | 192 224 192 193 | 32 32 0 194 | 160 32 0 195 | 32 160 0 196 | 160 160 0 197 | 32 32 128 198 | 160 32 128 199 | 32 160 128 200 | 160 160 128 201 | 96 32 0 202 | 224 32 0 203 | 96 160 0 204 | 224 160 0 205 | 96 32 128 206 | 224 32 128 207 | 96 160 128 208 | 224 160 128 209 | 32 96 0 210 | 160 96 0 211 | 32 224 0 212 | 160 224 0 213 | 32 96 128 214 | 160 96 128 215 | 32 224 128 216 | 160 224 128 217 | 96 96 0 218 | 224 
# datasets/palette.txt (continued): 96 0 219 | 96 224 0 220 | 224 224 0 221 | 96 96 128 222 | 224 96 128 223 | 96 224 128 224 | 224 224 128 225 | 32 32 64 226 | 160 32 64 227 | 32 160 64 228 | 160 160 64 229 | 32 32 192 230 | 160 32 192 231 | 32 160 192 232 | 160 160 192 233 | 96 32 64 234 | 224 32 64 235 | 96 160 64 236 | 224 160 64 237 | 96 32 192 238 | 224 32 192 239 | 96 160 192 240 | 224 160 192 241 | 32 96 64 242 | 160 96 64 243 | 32 224 64 244 | 160 224 64 245 | 32 96 192 246 | 160 96 192 247 | 32 224 192 248 | 160 224 192 249 | 96 96 64 250 | 224 96 64 251 | 96 224 64 252 | 224 224 64 253 | 96 96 192 254 | 224 96 192 255 | 96 224 192 256 | 224 224 192 257 |
# --------------------------------------------------------------------------------
# /doc/GAL-DeepLabv3+.png
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/ruirangerfan/GAL-DeepLabv3Plus/da613d0907ebf2908978a08f72b1a58e23caafb9/doc/GAL-DeepLabv3+.png
# --------------------------------------------------------------------------------
# /doc/GAL.png
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/ruirangerfan/GAL-DeepLabv3Plus/da613d0907ebf2908978a08f72b1a58e23caafb9/doc/GAL.png
# --------------------------------------------------------------------------------
# /models/__init__.py
# --------------------------------------------------------------------------------
import importlib
import sys

from models.base_model import BaseModel


def find_model_using_name(model_name):
    """Return the model class defined in ``models/<model_name>_model.py``.

    Given the option ``--model [modelname]``, the module
    ``models.<modelname>_model`` is imported and searched for a subclass of
    ``BaseModel`` whose name equals ``<modelname>model`` (underscores removed,
    compared case-insensitively). If several attributes match, the last one
    found wins.

    Exits the process with status 1 if no matching class exists.
    """
    model_filename = "models." + model_name + "_model"
    modellib = importlib.import_module(model_filename)

    model = None
    target_model_name = model_name.replace('_', '') + 'model'
    for name, cls in modellib.__dict__.items():
        # issubclass() raises TypeError for non-class attributes.
        if (name.lower() == target_model_name.lower()
                and isinstance(cls, type)
                and issubclass(cls, BaseModel)):
            model = cls

    if model is None:
        print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
        # Non-zero status so that calling shell scripts notice the failure.
        sys.exit(1)

    return model


def get_option_setter(model_name):
    """Return the ``modify_commandline_options`` hook of the named model class."""
    model_class = find_model_using_name(model_name)
    return model_class.modify_commandline_options


def create_model(opt, dataset):
    """Instantiate the model selected by ``opt.model`` and initialize it with ``opt`` and ``dataset``."""
    model = find_model_using_name(opt.model)
    instance = model()
    instance.initialize(opt, dataset)
    print("model [%s] was created" % (instance.name()))
    return instance


# --------------------------------------------------------------------------------
# /models/base_model.py
# --------------------------------------------------------------------------------
import os
import torch
from collections import OrderedDict
from . import networks


class BaseModel():
    """Base class for models: device selection, schedulers, checkpoints, logging hooks.

    Subclasses fill ``loss_names``, ``model_names`` and ``visual_names``. Each
    string ``X`` in ``model_names`` refers to the attribute ``netX``; each
    string ``Y`` in ``loss_names`` refers to the attribute ``loss_Y``.
    """

    # modify parser to add command line options,
    # and also change the default values if needed
    @staticmethod
    def modify_commandline_options(parser, is_train):
        return parser

    def name(self):
        return 'BaseModel'

    def initialize(self, opt):
        """Store options, pick the device (first GPU id, else CPU) and reset bookkeeping lists."""
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
        self.loss_names = []
        self.model_names = []
        self.visual_names = []
        self.image_names = []
        self.image_oriSize = []

    def set_input(self, input):
        self.input = input

    def forward(self):
        pass

    def _named_networks(self):
        """Yield ``(name, net)`` for every string entry of ``model_names``."""
        for name in self.model_names:
            if isinstance(name, str):
                yield name, getattr(self, 'net' + name)

    def setup(self, opt, parser=None):
        """Create LR schedulers (training), load weights (testing or resuming) and print the networks.

        ``parser`` is accepted for compatibility and is not used.
        """
        if self.isTrain:
            self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]

        if not self.isTrain or opt.continue_train:
            self.load_networks(opt.epoch)
        self.print_networks(opt.verbose)

    def eval(self):
        """Put every network into eval mode (used at test time)."""
        for _, net in self._named_networks():
            net.eval()

    def train(self):
        """Put every network into training mode."""
        for _, net in self._named_networks():
            net.train()

    def test(self):
        """Run ``forward`` under ``torch.no_grad()`` so no autograd graph is kept."""
        with torch.no_grad():
            self.forward()

    def get_image_names(self):
        return self.image_names

    def optimize_parameters(self):
        pass

    def update_learning_rate(self):
        """Step every scheduler (called once per epoch) and print the new learning rate."""
        for scheduler in self.schedulers:
            scheduler.step()
        lr = self.optimizers[0].param_groups[0]['lr']
        print('learning rate = %.7f' % lr)

    def get_current_visuals(self):
        """Return the attributes named in ``visual_names`` (train.py displays them in tensorboardX)."""
        visual_ret = OrderedDict()
        for name in self.visual_names:
            if isinstance(name, str):
                visual_ret[name] = getattr(self, name)
        return visual_ret

    def get_current_losses(self):
        """Return the training losses named in ``loss_names`` as floats (train.py prints them)."""
        errors_ret = OrderedDict()
        for name in self.loss_names:
            if isinstance(name, str):
                # float(...) works for both scalar tensor and float number
                errors_ret[name] = float(getattr(self, 'loss_' + name))
        return errors_ret

    def save_networks(self, epoch):
        """Save each network's state dict to ``<save_dir>/<epoch>_net_<name>.pth``.

        ``DataParallel`` wrappers are unwrapped (as in ``load_networks``) so
        that saved keys carry no ``module.`` prefix. Weights are saved from
        CPU. When GPUs are in use, the network is moved back to
        ``gpu_ids[0]`` even if saving fails.
        """
        os.makedirs(self.save_dir, exist_ok=True)
        on_gpu = len(self.gpu_ids) > 0 and torch.cuda.is_available()
        for name, net in self._named_networks():
            save_path = os.path.join(self.save_dir, '%s_net_%s.pth' % (epoch, name))
            module = net.module if isinstance(net, torch.nn.DataParallel) else net
            try:
                torch.save(module.cpu().state_dict(), save_path)
            finally:
                if on_gpu:
                    net.cuda(self.gpu_ids[0])

    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
        # Walk ``keys`` down the module tree; at the leaf, drop InstanceNorm
        # running stats / num_batches_tracked entries that the module does not use.
        key = keys[i]
        if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
            if module.__class__.__name__.startswith('InstanceNorm') and \
                    (key == 'running_mean' or key == 'running_var'):
                if getattr(module, key) is None:
                    state_dict.pop('.'.join(keys))
            if module.__class__.__name__.startswith('InstanceNorm') and \
               (key == 'num_batches_tracked'):
                state_dict.pop('.'.join(keys))
        else:
            self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)

    def load_networks(self, epoch):
        """Load each network from ``<save_dir>/<epoch>_net_<name>.pth``."""
        for name, net in self._named_networks():
            load_path = os.path.join(self.save_dir, '%s_net_%s.pth' % (epoch, name))
            if isinstance(net, torch.nn.DataParallel):
                net = net.module
            print('loading the model from %s' % load_path)
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            state_dict = torch.load(load_path, map_location=str(self.device))
            if hasattr(state_dict, '_metadata'):
                del state_dict._metadata

            for key in list(state_dict.keys()):
                self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
            net.load_state_dict(state_dict)

    def print_networks(self, verbose):
        """Print each network's parameter count (and the full module if ``verbose``)."""
        print('---------- Networks initialized -------------')
        for name, net in self._named_networks():
            num_params = sum(param.numel() for param in net.parameters())
            if verbose:
                print(net)
            print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
        print('-----------------------------------------------')

    def set_requires_grad(self, nets, requires_grad=False):
        """Set ``requires_grad`` on all parameters of ``nets`` (set False to avoid gradient computation)."""
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad


# --------------------------------------------------------------------------------
# /models/galnet_model.py (dump text; continues on the next line)
# --------------------------------------------------------------------------------
# 1 | import torch 2 | from .base_model import BaseModel 3 | from .
import networks 4 | 5 | 6 | class GALNetModel(BaseModel): 7 | def name(self): 8 | return 'GALNet' 9 | 10 | @staticmethod 11 | def modify_commandline_options(parser, is_train=True): 12 | # changing the default values 13 | if is_train: 14 | parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss') 15 | return parser 16 | 17 | def initialize(self, opt, dataset): 18 | BaseModel.initialize(self, opt) 19 | self.isTrain = opt.isTrain 20 | # specify the training losses you want to print out. The program will call base_model.get_current_losses 21 | self.loss_names = ['segmentation'] 22 | # specify the images you want to save/display. The program will call base_model.get_current_visuals 23 | self.visual_names = ['rgb_image', 'tdisp_image', 'label', 'output'] 24 | # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks 25 | self.model_names = ['GALNet'] 26 | 27 | # load/define networks 28 | if opt.input == "rgb": 29 | print("Using RGB images as input") 30 | self.input_channels = 3 31 | elif opt.input == "tdisp": 32 | print("Using transformed disparity images as input") 33 | self.input_channels = 1 34 | else: 35 | raise NotImplementedError 36 | 37 | self.netGALNet = networks.define_GALNet(dataset.num_labels, gpu_ids= self.gpu_ids, input_channels= self.input_channels, use_gal=opt.gal) 38 | # define loss functions 39 | self.criterionSegmentation = networks.SegmantationLoss(class_weights=None).to(self.device) 40 | 41 | if self.isTrain: 42 | # initialize optimizers 43 | self.optimizers = [] 44 | self.optimizer = torch.optim.SGD(self.netGALNet.parameters(), lr=opt.lr, momentum=opt.momentum, weight_decay=opt.weight_decay) 45 | self.optimizers.append(self.optimizer) 46 | self.set_requires_grad(self.netGALNet, True) 47 | 48 | def set_input(self, input): 49 | self.rgb_image = input['rgb_image'].to(self.device) 50 | self.tdisp_image = input['tdisp_image'].to(self.device) 51 | self.label 
= input['label'].to(self.device) 52 | self.image_names = input['path'] 53 | 54 | def forward(self): 55 | if self.opt.input == "rgb": 56 | self.output = self.netGALNet(self.rgb_image) 57 | elif self.opt.input == "tdisp": 58 | self.output = self.netGALNet(self.tdisp_image) 59 | else: 60 | raise NotImplementedError 61 | 62 | def get_loss(self): 63 | self.loss_segmentation = self.criterionSegmentation(self.output, self.label) 64 | 65 | def backward(self): 66 | self.loss_segmentation.backward() 67 | 68 | def optimize_parameters(self): 69 | self.forward() 70 | self.optimizer.zero_grad() 71 | self.get_loss() 72 | self.backward() 73 | self.optimizer.step() 74 | -------------------------------------------------------------------------------- /models/include/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruirangerfan/GAL-DeepLabv3Plus/da613d0907ebf2908978a08f72b1a58e23caafb9/models/include/__init__.py -------------------------------------------------------------------------------- /models/include/deeplabv3plus_inc/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruirangerfan/GAL-DeepLabv3Plus/da613d0907ebf2908978a08f72b1a58e23caafb9/models/include/deeplabv3plus_inc/__init__.py -------------------------------------------------------------------------------- /models/include/deeplabv3plus_inc/modeling/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruirangerfan/GAL-DeepLabv3Plus/da613d0907ebf2908978a08f72b1a58e23caafb9/models/include/deeplabv3plus_inc/modeling/__init__.py -------------------------------------------------------------------------------- /models/include/deeplabv3plus_inc/modeling/aspp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import 
torch.nn.functional as F 4 | from .sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 5 | 6 | 7 | class _ASPPModule(nn.Module): 8 | def __init__(self, inplanes, planes, kernel_size, padding, dilation, BatchNorm): 9 | super(_ASPPModule, self).__init__() 10 | self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size, 11 | stride=1, padding=padding, dilation=dilation, bias=False) 12 | self.bn = BatchNorm(planes) 13 | self.relu = nn.ReLU() 14 | 15 | self._init_weight() 16 | 17 | def forward(self, x): 18 | x = self.atrous_conv(x) 19 | x = self.bn(x) 20 | 21 | return self.relu(x) 22 | 23 | def _init_weight(self): 24 | for m in self.modules(): 25 | if isinstance(m, nn.Conv2d): 26 | torch.nn.init.kaiming_normal_(m.weight) 27 | elif isinstance(m, SynchronizedBatchNorm2d): 28 | m.weight.data.fill_(1) 29 | m.bias.data.zero_() 30 | elif isinstance(m, nn.BatchNorm2d): 31 | m.weight.data.fill_(1) 32 | m.bias.data.zero_() 33 | 34 | class ASPP(nn.Module): 35 | def __init__(self, backbone, output_stride, BatchNorm): 36 | super(ASPP, self).__init__() 37 | if backbone == 'drn': 38 | inplanes = 512 39 | elif backbone == 'mobilenet': 40 | inplanes = 320 41 | else: 42 | inplanes = 2048 43 | if output_stride == 16: 44 | dilations = [1, 6, 12, 18] 45 | elif output_stride == 8: 46 | dilations = [1, 12, 24, 36] 47 | else: 48 | raise NotImplementedError 49 | 50 | self.aspp1 = _ASPPModule(inplanes, 256, 1, padding=0, dilation=dilations[0], BatchNorm=BatchNorm) 51 | self.aspp2 = _ASPPModule(inplanes, 256, 3, padding=dilations[1], dilation=dilations[1], BatchNorm=BatchNorm) 52 | self.aspp3 = _ASPPModule(inplanes, 256, 3, padding=dilations[2], dilation=dilations[2], BatchNorm=BatchNorm) 53 | self.aspp4 = _ASPPModule(inplanes, 256, 3, padding=dilations[3], dilation=dilations[3], BatchNorm=BatchNorm) 54 | 55 | self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), 56 | nn.Conv2d(inplanes, 256, 1, stride=1, bias=False), 57 | BatchNorm(256), 58 | nn.ReLU()) 59 | 60 | 
self.conv1 = nn.Conv2d(1280, 256, 1, bias=False) 61 | self.bn1 = BatchNorm(256) 62 | self.relu = nn.ReLU() 63 | self.dropout = nn.Dropout(0.5) 64 | self._init_weight() 65 | 66 | def forward(self, x): 67 | x1 = self.aspp1(x) 68 | x2 = self.aspp2(x) 69 | x3 = self.aspp3(x) 70 | x4 = self.aspp4(x) 71 | x5 = self.global_avg_pool(x) 72 | x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True) 73 | x = torch.cat((x1, x2, x3, x4, x5), dim=1) 74 | 75 | x = self.conv1(x) 76 | x = self.bn1(x) 77 | x = self.relu(x) 78 | 79 | return self.dropout(x) 80 | 81 | def _init_weight(self): 82 | for m in self.modules(): 83 | if isinstance(m, nn.Conv2d): 84 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 85 | # m.weight.data.normal_(0, math.sqrt(2. / n)) 86 | torch.nn.init.kaiming_normal_(m.weight) 87 | elif isinstance(m, SynchronizedBatchNorm2d): 88 | m.weight.data.fill_(1) 89 | m.bias.data.zero_() 90 | elif isinstance(m, nn.BatchNorm2d): 91 | m.weight.data.fill_(1) 92 | m.bias.data.zero_() 93 | 94 | 95 | def build_aspp(backbone, output_stride, BatchNorm): 96 | return ASPP(backbone, output_stride, BatchNorm) -------------------------------------------------------------------------------- /models/include/deeplabv3plus_inc/modeling/backbone/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruirangerfan/GAL-DeepLabv3Plus/da613d0907ebf2908978a08f72b1a58e23caafb9/models/include/deeplabv3plus_inc/modeling/backbone/__init__.py -------------------------------------------------------------------------------- /models/include/deeplabv3plus_inc/modeling/backbone/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.utils.model_zoo as model_zoo 5 | from models.include.deeplabv3plus_inc.modeling.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 6 | 7 | 8 | class 
Bottleneck(nn.Module): 9 | expansion = 4 10 | 11 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None): 12 | super(Bottleneck, self).__init__() 13 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 14 | self.bn1 = BatchNorm(planes) 15 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 16 | dilation=dilation, padding=dilation, bias=False) 17 | self.bn2 = BatchNorm(planes) 18 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 19 | self.bn3 = BatchNorm(planes * 4) 20 | self.relu = nn.ReLU(inplace=True) 21 | self.downsample = downsample 22 | self.stride = stride 23 | self.dilation = dilation 24 | 25 | def forward(self, x): 26 | residual = x 27 | 28 | out = self.conv1(x) 29 | out = self.bn1(out) 30 | out = self.relu(out) 31 | 32 | out = self.conv2(out) 33 | out = self.bn2(out) 34 | out = self.relu(out) 35 | 36 | out = self.conv3(out) 37 | out = self.bn3(out) 38 | 39 | if self.downsample is not None: 40 | residual = self.downsample(x) 41 | 42 | out += residual 43 | out = self.relu(out) 44 | 45 | return out 46 | 47 | class ResNet(nn.Module): 48 | def __init__(self, block, layers, output_stride, BatchNorm, pretrained=True, num_ch=3): 49 | self.inplanes = 64 50 | super(ResNet, self).__init__() 51 | 52 | self.num_ch = num_ch 53 | 54 | blocks = [1, 2, 4] 55 | if output_stride == 16: 56 | strides = [1, 2, 2, 1] 57 | dilations = [1, 1, 1, 2] 58 | elif output_stride == 8: 59 | strides = [1, 2, 1, 1] 60 | dilations = [1, 1, 2, 4] 61 | else: 62 | raise NotImplementedError 63 | 64 | # Modules 65 | self.conv1 = nn.Conv2d(num_ch, 64, kernel_size=7, stride=2, padding=3, bias=False) 66 | self.bn1 = BatchNorm(64) 67 | self.relu = nn.ReLU(inplace=True) 68 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 69 | 70 | self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0], BatchNorm=BatchNorm) 71 | self.layer2 = self._make_layer(block, 
128, layers[1], stride=strides[1], dilation=dilations[1], BatchNorm=BatchNorm) 72 | self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2], BatchNorm=BatchNorm) 73 | self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm) 74 | self._init_weight() 75 | 76 | if pretrained: 77 | self._load_pretrained_model() 78 | 79 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None): 80 | downsample = None 81 | if stride != 1 or self.inplanes != planes * block.expansion: 82 | downsample = nn.Sequential( 83 | nn.Conv2d(self.inplanes, planes * block.expansion, 84 | kernel_size=1, stride=stride, bias=False), 85 | BatchNorm(planes * block.expansion), 86 | ) 87 | 88 | layers = [] 89 | layers.append(block(self.inplanes, planes, stride, dilation, downsample, BatchNorm)) 90 | self.inplanes = planes * block.expansion 91 | for i in range(1, blocks): 92 | layers.append(block(self.inplanes, planes, dilation=dilation, BatchNorm=BatchNorm)) 93 | 94 | return nn.Sequential(*layers) 95 | 96 | def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None): 97 | downsample = None 98 | if stride != 1 or self.inplanes != planes * block.expansion: 99 | downsample = nn.Sequential( 100 | nn.Conv2d(self.inplanes, planes * block.expansion, 101 | kernel_size=1, stride=stride, bias=False), 102 | BatchNorm(planes * block.expansion), 103 | ) 104 | 105 | layers = [] 106 | layers.append(block(self.inplanes, planes, stride, dilation=blocks[0]*dilation, 107 | downsample=downsample, BatchNorm=BatchNorm)) 108 | self.inplanes = planes * block.expansion 109 | for i in range(1, len(blocks)): 110 | layers.append(block(self.inplanes, planes, stride=1, 111 | dilation=blocks[i]*dilation, BatchNorm=BatchNorm)) 112 | 113 | return nn.Sequential(*layers) 114 | 115 | def forward(self, input): 116 | x = self.conv1(input) 117 | x = self.bn1(x) 118 | x = self.relu(x) 119 | x 
= self.maxpool(x) 120 | 121 | x = self.layer1(x) 122 | low_level_feat = x 123 | x = self.layer2(x) 124 | x = self.layer3(x) 125 | x = self.layer4(x) 126 | return x, low_level_feat 127 | 128 | def _init_weight(self): 129 | for m in self.modules(): 130 | if isinstance(m, nn.Conv2d): 131 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 132 | m.weight.data.normal_(0, math.sqrt(2. / n)) 133 | elif isinstance(m, SynchronizedBatchNorm2d): 134 | m.weight.data.fill_(1) 135 | m.bias.data.zero_() 136 | elif isinstance(m, nn.BatchNorm2d): 137 | m.weight.data.fill_(1) 138 | m.bias.data.zero_() 139 | 140 | def _load_pretrained_model(self): 141 | pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet50-19c8e357.pth') 142 | model_dict = {} 143 | state_dict = self.state_dict() 144 | for k, v in pretrain_dict.items(): 145 | if k == "conv1.weight" and self.num_ch == 1: 146 | model_dict[k] = torch.unsqueeze(torch.mean(v, dim=1), dim=1) 147 | continue 148 | if k in state_dict: 149 | model_dict[k] = v 150 | state_dict.update(model_dict) 151 | self.load_state_dict(state_dict) 152 | 153 | 154 | def ResNet50(output_stride, BatchNorm, pretrained=True, num_ch=1): 155 | """Constructs a ResNet-50 model. 156 | Args: 157 | pretrained (bool): If True, returns a model pre-trained on ImageNet 158 | """ 159 | model = ResNet(Bottleneck, [3, 4, 6, 3], output_stride, BatchNorm, pretrained=pretrained, num_ch=num_ch) 160 | return model 161 | 162 | def ResNet101(output_stride, BatchNorm, pretrained=True, num_ch=1): 163 | """Constructs a ResNet-101 model. 
164 | Args: 165 | pretrained (bool): If True, returns a model pre-trained on ImageNet 166 | """ 167 | model = ResNet(Bottleneck, [3, 4, 23, 3], output_stride, BatchNorm, pretrained=pretrained, num_ch=num_ch) 168 | return model 169 | -------------------------------------------------------------------------------- /models/include/deeplabv3plus_inc/modeling/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 5 | 6 | 7 | class Decoder(nn.Module): 8 | def __init__(self, num_classes, backbone, BatchNorm): 9 | super(Decoder, self).__init__() 10 | if backbone == 'resnet' or backbone == 'drn': 11 | low_level_inplanes = 256 12 | elif backbone == 'xception': 13 | low_level_inplanes = 128 14 | elif backbone == 'mobilenet': 15 | low_level_inplanes = 24 16 | else: 17 | raise NotImplementedError 18 | 19 | self.conv1 = nn.Conv2d(low_level_inplanes, 48, 1, bias=False) 20 | self.bn1 = BatchNorm(48) 21 | self.relu = nn.ReLU() 22 | self.last_conv = nn.Sequential(nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False), 23 | BatchNorm(256), 24 | nn.ReLU(), 25 | nn.Dropout(0.5), 26 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False), 27 | BatchNorm(256), 28 | nn.ReLU(), 29 | nn.Dropout(0.1), 30 | nn.Conv2d(256, num_classes, kernel_size=1, stride=1)) 31 | self._init_weight() 32 | 33 | def forward(self, x, low_level_feat): 34 | low_level_feat = self.conv1(low_level_feat) 35 | low_level_feat = self.bn1(low_level_feat) 36 | low_level_feat = self.relu(low_level_feat) 37 | 38 | x = F.interpolate(x, size=low_level_feat.size()[2:], mode='bilinear', align_corners=True) 39 | x = torch.cat((x, low_level_feat), dim=1) 40 | x = self.last_conv(x) 41 | 42 | return x 43 | 44 | def _init_weight(self): 45 | for m in self.modules(): 46 | if isinstance(m, nn.Conv2d): 47 | 
torch.nn.init.kaiming_normal_(m.weight) 48 | elif isinstance(m, SynchronizedBatchNorm2d): 49 | m.weight.data.fill_(1) 50 | m.bias.data.zero_() 51 | elif isinstance(m, nn.BatchNorm2d): 52 | m.weight.data.fill_(1) 53 | m.bias.data.zero_() 54 | 55 | 56 | def build_decoder(num_classes, backbone, BatchNorm): 57 | return Decoder(num_classes, backbone, BatchNorm) 58 | -------------------------------------------------------------------------------- /models/include/deeplabv3plus_inc/modeling/sync_batchnorm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .replicate import DataParallelWithCallback, patch_replication_callback -------------------------------------------------------------------------------- /models/include/deeplabv3plus_inc/modeling/sync_batchnorm/batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
10 | 11 | import collections 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | 16 | from torch.nn.modules.batchnorm import _BatchNorm 17 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast 18 | 19 | from .comm import SyncMaster 20 | 21 | __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d'] 22 | 23 | 24 | def _sum_ft(tensor): 25 | """sum over the first and last dimension""" 26 | return tensor.sum(dim=0).sum(dim=-1) 27 | 28 | 29 | def _unsqueeze_ft(tensor): 30 | """add new dimensions at the front and the tail""" 31 | return tensor.unsqueeze(0).unsqueeze(-1) 32 | 33 | 34 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) 35 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) 36 | 37 | 38 | class _SynchronizedBatchNorm(_BatchNorm): 39 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True): 40 | super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine) 41 | 42 | self._sync_master = SyncMaster(self._data_parallel_master) 43 | 44 | self._is_parallel = False 45 | self._parallel_id = None 46 | self._slave_pipe = None 47 | 48 | def forward(self, input): 49 | # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. 50 | if not (self._is_parallel and self.training): 51 | return F.batch_norm( 52 | input, self.running_mean, self.running_var, self.weight, self.bias, 53 | self.training, self.momentum, self.eps) 54 | 55 | # Resize the input to (B, C, -1). 56 | input_shape = input.size() 57 | input = input.view(input.size(0), self.num_features, -1) 58 | 59 | # Compute the sum and square-sum. 60 | sum_size = input.size(0) * input.size(2) 61 | input_sum = _sum_ft(input) 62 | input_ssum = _sum_ft(input ** 2) 63 | 64 | # Reduce-and-broadcast the statistics.
65 | if self._parallel_id == 0: 66 | mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) 67 | else: 68 | mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) 69 | 70 | # Compute the output. 71 | if self.affine: 72 | # MJY:: Fuse the multiplication for speed. 73 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) 74 | else: 75 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) 76 | 77 | # Reshape it. 78 | return output.view(input_shape) 79 | 80 | def __data_parallel_replicate__(self, ctx, copy_id): 81 | self._is_parallel = True 82 | self._parallel_id = copy_id 83 | 84 | # parallel_id == 0 means master device. 85 | if self._parallel_id == 0: 86 | ctx.sync_master = self._sync_master 87 | else: 88 | self._slave_pipe = ctx.sync_master.register_slave(copy_id) 89 | 90 | def _data_parallel_master(self, intermediates): 91 | """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" 92 | 93 | # Always using same "device order" makes the ReduceAdd operation faster. 
94 | # Thanks to:: Tete Xiao (http://tetexiao.com/) 95 | intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) 96 | 97 | to_reduce = [i[1][:2] for i in intermediates] 98 | to_reduce = [j for i in to_reduce for j in i] # flatten 99 | target_gpus = [i[1].sum.get_device() for i in intermediates] 100 | 101 | sum_size = sum([i[1].sum_size for i in intermediates]) 102 | sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) 103 | mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) 104 | 105 | broadcasted = Broadcast.apply(target_gpus, mean, inv_std) 106 | 107 | outputs = [] 108 | for i, rec in enumerate(intermediates): 109 | outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2]))) 110 | 111 | return outputs 112 | 113 | def _compute_mean_std(self, sum_, ssum, size): 114 | """Compute the mean and standard-deviation with sum and square-sum. This method 115 | also maintains the moving average on the master device.""" 116 | assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' 117 | mean = sum_ / size 118 | sumvar = ssum - sum_ * mean 119 | unbias_var = sumvar / (size - 1) 120 | bias_var = sumvar / size 121 | 122 | self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data 123 | self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data 124 | 125 | return mean, bias_var.clamp(self.eps) ** -0.5 126 | 127 | 128 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): 129 | r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a 130 | mini-batch. 131 | .. math:: 132 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 133 | This module differs from the built-in PyTorch BatchNorm1d as the mean and 134 | standard-deviation are reduced across all devices during training. 
135 | For example, when one uses `nn.DataParallel` to wrap the network during 136 | training, PyTorch's implementation normalize the tensor on each device using 137 | the statistics only on that device, which accelerated the computation and 138 | is also easy to implement, but the statistics might be inaccurate. 139 | Instead, in this synchronized version, the statistics will be computed 140 | over all training samples distributed on multiple devices. 141 | 142 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 143 | as the built-in PyTorch implementation. 144 | The mean and standard-deviation are calculated per-dimension over 145 | the mini-batches and gamma and beta are learnable parameter vectors 146 | of size C (where C is the input size). 147 | During training, this layer keeps a running estimate of its computed mean 148 | and variance. The running sum is kept with a default momentum of 0.1. 149 | During evaluation, this running mean/variance is used for normalization. 150 | Because the BatchNorm is done over the `C` dimension, computing statistics 151 | on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm 152 | Args: 153 | num_features: num_features from an expected input of size 154 | `batch_size x num_features [x width]` 155 | eps: a value added to the denominator for numerical stability. 156 | Default: 1e-5 157 | momentum: the value used for the running_mean and running_var 158 | computation. Default: 0.1 159 | affine: a boolean value that when set to ``True``, gives the layer learnable 160 | affine parameters. 
Default: ``True`` 161 | Shape: 162 | - Input: :math:`(N, C)` or :math:`(N, C, L)` 163 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) 164 | Examples: 165 | >>> # With Learnable Parameters 166 | >>> m = SynchronizedBatchNorm1d(100) 167 | >>> # Without Learnable Parameters 168 | >>> m = SynchronizedBatchNorm1d(100, affine=False) 169 | >>> input = torch.autograd.Variable(torch.randn(20, 100)) 170 | >>> output = m(input) 171 | """ 172 | 173 | def _check_input_dim(self, input): 174 | if input.dim() != 2 and input.dim() != 3: 175 | raise ValueError('expected 2D or 3D input (got {}D input)' 176 | .format(input.dim())) 177 | super(SynchronizedBatchNorm1d, self)._check_input_dim(input) 178 | 179 | 180 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): 181 | r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch 182 | of 3d inputs 183 | .. math:: 184 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 185 | This module differs from the built-in PyTorch BatchNorm2d as the mean and 186 | standard-deviation are reduced across all devices during training. 187 | For example, when one uses `nn.DataParallel` to wrap the network during 188 | training, PyTorch's implementation normalize the tensor on each device using 189 | the statistics only on that device, which accelerated the computation and 190 | is also easy to implement, but the statistics might be inaccurate. 191 | Instead, in this synchronized version, the statistics will be computed 192 | over all training samples distributed on multiple devices. 193 | 194 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 195 | as the built-in PyTorch implementation. 196 | The mean and standard-deviation are calculated per-dimension over 197 | the mini-batches and gamma and beta are learnable parameter vectors 198 | of size C (where C is the input size). 
199 | During training, this layer keeps a running estimate of its computed mean 200 | and variance. The running sum is kept with a default momentum of 0.1. 201 | During evaluation, this running mean/variance is used for normalization. 202 | Because the BatchNorm is done over the `C` dimension, computing statistics 203 | on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm 204 | Args: 205 | num_features: num_features from an expected input of 206 | size batch_size x num_features x height x width 207 | eps: a value added to the denominator for numerical stability. 208 | Default: 1e-5 209 | momentum: the value used for the running_mean and running_var 210 | computation. Default: 0.1 211 | affine: a boolean value that when set to ``True``, gives the layer learnable 212 | affine parameters. Default: ``True`` 213 | Shape: 214 | - Input: :math:`(N, C, H, W)` 215 | - Output: :math:`(N, C, H, W)` (same shape as input) 216 | Examples: 217 | >>> # With Learnable Parameters 218 | >>> m = SynchronizedBatchNorm2d(100) 219 | >>> # Without Learnable Parameters 220 | >>> m = SynchronizedBatchNorm2d(100, affine=False) 221 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) 222 | >>> output = m(input) 223 | """ 224 | 225 | def _check_input_dim(self, input): 226 | if input.dim() != 4: 227 | raise ValueError('expected 4D input (got {}D input)' 228 | .format(input.dim())) 229 | super(SynchronizedBatchNorm2d, self)._check_input_dim(input) 230 | 231 | 232 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): 233 | r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch 234 | of 4d inputs 235 | .. math:: 236 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 237 | This module differs from the built-in PyTorch BatchNorm3d as the mean and 238 | standard-deviation are reduced across all devices during training. 
239 | For example, when one uses `nn.DataParallel` to wrap the network during 240 | training, PyTorch's implementation normalize the tensor on each device using 241 | the statistics only on that device, which accelerated the computation and 242 | is also easy to implement, but the statistics might be inaccurate. 243 | Instead, in this synchronized version, the statistics will be computed 244 | over all training samples distributed on multiple devices. 245 | 246 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 247 | as the built-in PyTorch implementation. 248 | The mean and standard-deviation are calculated per-dimension over 249 | the mini-batches and gamma and beta are learnable parameter vectors 250 | of size C (where C is the input size). 251 | During training, this layer keeps a running estimate of its computed mean 252 | and variance. The running sum is kept with a default momentum of 0.1. 253 | During evaluation, this running mean/variance is used for normalization. 254 | Because the BatchNorm is done over the `C` dimension, computing statistics 255 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm 256 | or Spatio-temporal BatchNorm 257 | Args: 258 | num_features: num_features from an expected input of 259 | size batch_size x num_features x depth x height x width 260 | eps: a value added to the denominator for numerical stability. 261 | Default: 1e-5 262 | momentum: the value used for the running_mean and running_var 263 | computation. Default: 0.1 264 | affine: a boolean value that when set to ``True``, gives the layer learnable 265 | affine parameters. 
Default: ``True`` 266 | Shape: 267 | - Input: :math:`(N, C, D, H, W)` 268 | - Output: :math:`(N, C, D, H, W)` (same shape as input) 269 | Examples: 270 | >>> # With Learnable Parameters 271 | >>> m = SynchronizedBatchNorm3d(100) 272 | >>> # Without Learnable Parameters 273 | >>> m = SynchronizedBatchNorm3d(100, affine=False) 274 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) 275 | >>> output = m(input) 276 | """ 277 | 278 | def _check_input_dim(self, input): 279 | if input.dim() != 5: 280 | raise ValueError('expected 5D input (got {}D input)' 281 | .format(input.dim())) 282 | super(SynchronizedBatchNorm3d, self)._check_input_dim(input) -------------------------------------------------------------------------------- /models/include/deeplabv3plus_inc/modeling/sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | from multiprocessing import Queue 12 | #import queue 13 | import collections 14 | import threading 15 | 16 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 17 | 18 | 19 | class FutureResult(object): 20 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 21 | 22 | def __init__(self): 23 | self._result = None 24 | self._lock = threading.Lock() 25 | self._cond = threading.Condition(self._lock) 26 | 27 | def put(self, result): 28 | with self._lock: 29 | assert self._result is None, 'Previous result has\'t been fetched.' 
30 | self._result = result 31 | self._cond.notify() 32 | 33 | def get(self): 34 | with self._lock: 35 | if self._result is None: 36 | self._cond.wait() 37 | 38 | res = self._result 39 | self._result = None 40 | return res 41 | 42 | 43 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 44 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 45 | 46 | 47 | class SlavePipe(_SlavePipeBase): 48 | """Pipe for master-slave communication.""" 49 | 50 | def run_slave(self, msg): 51 | self.queue.put((self.identifier, msg)) 52 | ret = self.result.get() 53 | self.queue.put(True) 54 | return ret 55 | 56 | 57 | class SyncMaster(object): 58 | """An abstract `SyncMaster` object. 59 | - During the replication, as the data parallel will trigger a callback on each module, all slave devices should 60 | call `register_slave(id)` and obtain a `SlavePipe` to communicate with the master. 61 | - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, 62 | and passed to a registered callback. 63 | - After receiving the messages, the master device should gather the information and determine the message passed 64 | back to each slave device. 65 | """ 66 | 67 | def __init__(self, master_callback): 68 | """ 69 | Args: 70 | master_callback: a callback to be invoked after having collected messages from slave devices. 71 | """ 72 | self._master_callback = master_callback 73 | self._queue = Queue() 74 | self._registry = collections.OrderedDict() 75 | self._activated = False 76 | 77 | def __getstate__(self): 78 | return {'master_callback': self._master_callback} 79 | 80 | def __setstate__(self, state): 81 | self.__init__(state['master_callback']) 82 | 83 | def register_slave(self, identifier): 84 | """ 85 | Register a slave device. 86 | Args: 87 | identifier: an identifier, usually the device id.
88 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 89 | """ 90 | if self._activated: 91 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 92 | self._activated = False 93 | self._registry.clear() 94 | future = FutureResult() 95 | self._registry[identifier] = _MasterRegistry(future) 96 | return SlavePipe(identifier, self._queue, future) 97 | 98 | def run_master(self, master_msg): 99 | """ 100 | Main entry for the master device in each forward pass. 101 | The messages were first collected from each devices (including the master device), and then 102 | an callback will be invoked to compute the message to be sent back to each devices 103 | (including the master device). 104 | Args: 105 | master_msg: the message that the master want to send to itself. This will be placed as the first 106 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 107 | Returns: the message to be sent back to the master device. 108 | """ 109 | self._activated = True 110 | 111 | intermediates = [(0, master_msg)] 112 | for i in range(self.nr_slaves): 113 | intermediates.append(self._queue.get()) 114 | 115 | results = self._master_callback(intermediates) 116 | assert results[0][0] == 0, 'The first result should belongs to the master.' 
117 | 118 | for i, res in results: 119 | if i == 0: 120 | continue 121 | self._registry[i].result.put(res) 122 | 123 | for i in range(self.nr_slaves): 124 | assert self._queue.get() is True 125 | 126 | return results[0][1] 127 | 128 | @property 129 | def nr_slaves(self): 130 | return len(self._registry) 131 | -------------------------------------------------------------------------------- /models/include/deeplabv3plus_inc/modeling/sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. 30 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 31 | Note that, as all modules are isomorphism, we assign each sub-module with a context 32 | (shared among multiple copies of this module on different devices). 33 | Through this context, different copies can share some information. 34 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 35 | of any slave copies. 
36 | """ 37 | master_copy = modules[0] 38 | nr_modules = len(list(master_copy.modules())) 39 | ctxs = [CallbackContext() for _ in range(nr_modules)] 40 | 41 | for i, module in enumerate(modules): 42 | for j, m in enumerate(module.modules()): 43 | if hasattr(m, '__data_parallel_replicate__'): 44 | m.__data_parallel_replicate__(ctxs[j], i) 45 | 46 | 47 | class DataParallelWithCallback(DataParallel): 48 | """ 49 | Data Parallel with a replication callback. 50 | An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 51 | original `replicate` function. 52 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 53 | Examples: 54 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 55 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 56 | # sync_bn.__data_parallel_replicate__ will be invoked. 57 | """ 58 | 59 | def replicate(self, module, device_ids): 60 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 61 | execute_replication_callbacks(modules) 62 | return modules 63 | 64 | 65 | def patch_replication_callback(data_parallel): 66 | """ 67 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 68 | Useful when you have customized `DataParallel` implementation. 
69 | Examples: 70 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 71 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 72 | > patch_replication_callback(sync_bn) 73 | # this is equivalent to 74 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 75 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 76 | """ 77 | 78 | assert isinstance(data_parallel, DataParallel) 79 | 80 | old_replicate = data_parallel.replicate 81 | 82 | @functools.wraps(old_replicate) 83 | def new_replicate(module, device_ids): 84 | modules = old_replicate(module, device_ids) 85 | execute_replication_callbacks(modules) 86 | return modules 87 | 88 | data_parallel.replicate = new_replicate -------------------------------------------------------------------------------- /models/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | from torch.optim import lr_scheduler 5 | import torch.nn.functional as F 6 | 7 | from models.include.deeplabv3plus_inc.modeling.sync_batchnorm.batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d 8 | from models.include.deeplabv3plus_inc.modeling.aspp import build_aspp 9 | from models.include.deeplabv3plus_inc.modeling.decoder import build_decoder 10 | from models.include.deeplabv3plus_inc.modeling.backbone import resnet 11 | 12 | 13 | ### help functions ### 14 | def get_scheduler(optimizer, opt): 15 | if opt.lr_policy == 'lambda': 16 | lambda_rule = lambda epoch: opt.lr_gamma ** ((epoch+1) // opt.lr_decay_epochs) 17 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 18 | elif opt.lr_policy == 'step': 19 | scheduler = lr_scheduler.StepLR(optimizer,step_size=opt.lr_decay_iters, gamma=0.1) 20 | else: 21 | return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) 22 | return scheduler 23 | 24 | def init_net(net, gpu_ids=[]): 25 | if len(gpu_ids) 
> 0: 26 | assert(torch.cuda.is_available()) 27 | net.to(gpu_ids[0]) 28 | net = torch.nn.DataParallel(net, gpu_ids) 29 | return net 30 | 31 | def define_GALNet(num_labels, gpu_ids=[], input_channels=1, use_gal=True): 32 | net = GALDeepLabV3Plus(n_class=num_labels, input_channels=input_channels, use_gal=use_gal) 33 | return init_net(net, gpu_ids) 34 | 35 | 36 | # Ref: https://github.com/jfzhang95/pytorch-deeplab-xception 37 | def build_backbone(backbone, output_stride, BatchNorm, input_channels): 38 | if backbone == 'resnet': 39 | return resnet.ResNet50(output_stride, BatchNorm, num_ch=input_channels) 40 | else: 41 | raise NotImplementedError 42 | 43 | class GALDeepLabV3Plus(nn.Module): 44 | def __init__(self, n_class=2, backbone='resnet', output_stride=16, sync_bn=True, freeze_bn=False, input_channels=1, use_gal=True): 45 | super(GALDeepLabV3Plus, self).__init__() 46 | 47 | if backbone == 'drn': 48 | output_stride = 8 49 | 50 | if sync_bn == True: 51 | BatchNorm = SynchronizedBatchNorm2d 52 | else: 53 | BatchNorm = nn.BatchNorm2d 54 | 55 | self.backbone = build_backbone(backbone, output_stride, BatchNorm, input_channels) 56 | self.aspp = build_aspp(backbone, output_stride, BatchNorm) 57 | self.decoder = build_decoder(n_class, backbone, BatchNorm) 58 | 59 | self.use_gal = False 60 | if use_gal: 61 | print("Using GAL") 62 | self.use_gal = True 63 | self.gal = GAL(sync_bn=sync_bn, input_channels=2048) 64 | 65 | if freeze_bn: 66 | self.freeze_bn() 67 | 68 | def forward(self, input): 69 | input = input.float() 70 | 71 | x, low_level_feat = self.backbone(input) 72 | 73 | if self.use_gal: 74 | x = self.gal(x) 75 | 76 | x = self.aspp(x) 77 | 78 | x = self.decoder(x, low_level_feat) 79 | x = F.interpolate(x, size=input.size()[2:], mode='bilinear', align_corners=True) 80 | 81 | return x 82 | 83 | def freeze_bn(self): 84 | for m in self.modules(): 85 | if isinstance(m, SynchronizedBatchNorm2d): 86 | m.eval() 87 | elif isinstance(m, nn.BatchNorm2d): 88 | m.eval() 89 | 90 | 91 | 
class GAL(nn.Module): 92 | def __init__(self, sync_bn=True, input_channels=2048): 93 | super(GAL, self).__init__() 94 | self.input_channels = input_channels 95 | if sync_bn == True: 96 | BatchNorm1d = SynchronizedBatchNorm1d 97 | BatchNorm2d = SynchronizedBatchNorm2d 98 | else: 99 | BatchNorm1d = nn.BatchNorm1d 100 | BatchNorm2d = nn.BatchNorm2d 101 | 102 | self.edge_aggregation_func = nn.Sequential( 103 | nn.Linear(4, 1), 104 | BatchNorm1d(1), 105 | nn.ReLU(inplace=True), 106 | ) 107 | self.vertex_update_func = nn.Sequential( 108 | nn.Linear(2 * input_channels, input_channels // 2), 109 | BatchNorm1d(input_channels // 2), 110 | nn.ReLU(inplace=True), 111 | ) 112 | 113 | self.edge_update_func = nn.Sequential( 114 | nn.Linear(2 * input_channels, input_channels // 2), 115 | BatchNorm1d(input_channels // 2), 116 | nn.ReLU(inplace=True), 117 | ) 118 | self.update_edge_reduce_func = nn.Sequential( 119 | nn.Linear(4, 1), 120 | BatchNorm1d(1), 121 | nn.ReLU(inplace=True), 122 | ) 123 | 124 | self.final_aggregation_layer = nn.Sequential( 125 | nn.Conv2d(input_channels + input_channels // 2, input_channels, kernel_size=1, stride=1, padding=0, bias=False), 126 | BatchNorm2d(input_channels), 127 | nn.ReLU(inplace=True), 128 | ) 129 | 130 | self._init_weight() 131 | 132 | def forward(self, input): 133 | x = input 134 | B, C, H, W = x.size() 135 | 136 | vertex = input 137 | edge = torch.stack( 138 | ( 139 | torch.cat((input[:,:,-1:], input[:,:,:-1]), dim=2), 140 | torch.cat((input[:,:,1:], input[:,:,:1]), dim=2), 141 | torch.cat((input[:,:,:,-1:], input[:,:,:,:-1]), dim=3), 142 | torch.cat((input[:,:,:,1:], input[:,:,:,:1]), dim=3) 143 | ), dim=-1 144 | ) * input.unsqueeze(dim=-1) 145 | 146 | aggregated_edge = self.edge_aggregation_func( 147 | edge.reshape(-1, 4) 148 | ).reshape((B, C, H, W)) 149 | cat_feature_for_vertex = torch.cat((vertex, aggregated_edge), dim=1) 150 | update_vertex = self.vertex_update_func( 151 | cat_feature_for_vertex.permute(0, 2, 3, 1).reshape((-1, 2 * 
self.input_channels)) 152 | ).reshape((B, H, W, self.input_channels // 2)).permute(0, 3, 1, 2) 153 | 154 | cat_feature_for_edge = torch.cat( 155 | ( 156 | torch.stack((vertex, vertex, vertex, vertex), dim=-1), 157 | edge 158 | ), dim=1 159 | ).permute(0, 2, 3, 4, 1).reshape((-1, 2 * self.input_channels)) 160 | update_edge = self.edge_update_func(cat_feature_for_edge).reshape((B, H, W, 4, C//2)).permute(0, 4, 1, 2, 3).reshape((-1, 4)) 161 | update_edge_converted = self.update_edge_reduce_func(update_edge).reshape((B, C//2, H, W)) 162 | 163 | update_feature = update_vertex * update_edge_converted 164 | output = self.final_aggregation_layer( 165 | torch.cat((x, update_feature), dim=1) 166 | ) 167 | 168 | return output 169 | 170 | def _init_weight(self): 171 | for m in self.modules(): 172 | if isinstance(m, nn.Conv2d): 173 | torch.nn.init.kaiming_normal_(m.weight) 174 | elif isinstance(m, nn.Linear): 175 | torch.nn.init.kaiming_normal_(m.weight) 176 | elif isinstance(m, SynchronizedBatchNorm1d): 177 | m.weight.data.fill_(1) 178 | m.bias.data.zero_() 179 | elif isinstance(m, nn.BatchNorm1d): 180 | m.weight.data.fill_(1) 181 | m.bias.data.zero_() 182 | elif isinstance(m, SynchronizedBatchNorm2d): 183 | m.weight.data.fill_(1) 184 | m.bias.data.zero_() 185 | elif isinstance(m, nn.BatchNorm2d): 186 | m.weight.data.fill_(1) 187 | m.bias.data.zero_() 188 | 189 | 190 | class SegmantationLoss(nn.Module): 191 | def __init__(self, class_weights=None): 192 | super(SegmantationLoss, self).__init__() 193 | self.loss = nn.CrossEntropyLoss(weight=class_weights) 194 | def __call__(self, output, target): 195 | return self.loss(output, target) 196 | -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruirangerfan/GAL-DeepLabv3Plus/da613d0907ebf2908978a08f72b1a58e23caafb9/options/__init__.py 
-------------------------------------------------------------------------------- /options/base_options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from util import util 4 | import torch 5 | import models 6 | import data 7 | 8 | 9 | class BaseOptions(): 10 | def __init__(self): 11 | self.initialized = False 12 | 13 | def initialize(self, parser): 14 | parser.add_argument('--batch_size', type=int, default=2, help='input batch size') 15 | parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') 16 | parser.add_argument('--name', type=str, default='pothole', help='name of the experiment. It decides where to store samples and models') 17 | parser.add_argument('--input', type=str, default='tdisp', help='chooses input images') 18 | parser.add_argument('--dataset', type=str, default='pothole', help='chooses which dataset to load.') 19 | parser.add_argument('--model', type=str, default='galnet', help='chooses which model to use.') 20 | parser.add_argument('--gal', action='store_true', help='if true, use gal') 21 | parser.add_argument('--epoch', type=str, default='best', help='chooses which epoch to load') 22 | parser.add_argument('--num_threads', default=2, type=int, help='# threads for loading data') 23 | parser.add_argument('--checkpoints_dir', type=str, default='./runs', help='models and records are saved here') 24 | parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') 25 | parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information') 26 | parser.add_argument('--seed', type=int, default=0, help='seed for random generators') 27 | self.initialized = True 28 | return parser 29 | 30 | def gather_options(self): 31 | # initialize parser with basic options 32 | if not self.initialized: 33 | parser = 
argparse.ArgumentParser( 34 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 35 | parser = self.initialize(parser) 36 | 37 | # get the basic options 38 | opt, _ = parser.parse_known_args() 39 | 40 | # modify model-related parser options 41 | model_name = opt.model 42 | model_option_setter = models.get_option_setter(model_name) 43 | parser = model_option_setter(parser, self.isTrain) 44 | opt, _ = parser.parse_known_args() # parse again with the new defaults 45 | 46 | # modify dataset-related parser options 47 | dataset_name = opt.dataset 48 | dataset_option_setter = data.get_option_setter(dataset_name) 49 | parser = dataset_option_setter(parser, self.isTrain) 50 | 51 | self.parser = parser 52 | 53 | return parser.parse_args() 54 | 55 | def print_options(self, opt): 56 | message = '' 57 | message += '----------------- Options ---------------\n' 58 | for k, v in sorted(vars(opt).items()): 59 | comment = '' 60 | default = self.parser.get_default(k) 61 | if v != default: 62 | comment = '\t[default: %s]' % str(default) 63 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) 64 | message += '----------------- End -------------------' 65 | print(message) 66 | 67 | # save to the disk 68 | expr_dir = os.path.join(opt.checkpoints_dir, opt.name) 69 | util.mkdirs(expr_dir) 70 | file_name = os.path.join(expr_dir, 'opt.txt') 71 | with open(file_name, 'wt') as opt_file: 72 | opt_file.write(message) 73 | opt_file.write('\n') 74 | 75 | def parse(self): 76 | opt = self.gather_options() 77 | opt.isTrain = self.isTrain # train or test 78 | 79 | self.print_options(opt) 80 | 81 | # set gpu ids 82 | str_ids = opt.gpu_ids.split(',') 83 | opt.gpu_ids = [] 84 | for str_id in str_ids: 85 | id = int(str_id) 86 | if id >= 0: 87 | opt.gpu_ids.append(id) 88 | if len(opt.gpu_ids) > 0: 89 | torch.cuda.set_device(opt.gpu_ids[0]) 90 | 91 | self.opt = opt 92 | return self.opt 93 | -------------------------------------------------------------------------------- 
/options/test_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TestOptions(BaseOptions): 5 | def initialize(self, parser): 6 | parser = BaseOptions.initialize(self, parser) 7 | parser.add_argument('--results_dir', type=str, default='./testresults/', help='saves results here.') 8 | parser.add_argument('--phase', type=str, default='test', help='train, val, test') 9 | self.isTrain = False 10 | return parser 11 | -------------------------------------------------------------------------------- /options/train_options.py: -------------------------------------------------------------------------------- 1 | from typing_extensions import Required 2 | from .base_options import BaseOptions 3 | 4 | 5 | class TrainOptions(BaseOptions): 6 | def initialize(self, parser): 7 | parser = BaseOptions.initialize(self, parser) 8 | parser.add_argument('--print_freq', type=int, default=10, help='frequency of showing training results on console') 9 | parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') 10 | parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count') 11 | parser.add_argument('--phase', type=str, default='train', help='train, val, test') 12 | parser.add_argument('--nepoch', type=int, default=1000, help='maximum epochs') 13 | parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam') 14 | parser.add_argument('--lr', type=float, default=0.001, help='initial learning rate for optimizer') 15 | parser.add_argument('--momentum', type=float, default=0.9, help='momentum factor for SGD') 16 | parser.add_argument('--weight_decay', type=float, default=0.0005, help='momentum factor for optimizer') 17 | parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step|plateau|cosine') 18 | parser.add_argument('--lr_decay_iters', 
type=int, default=5000000, help='multiply by a gamma every lr_decay_iters iterations') 19 | parser.add_argument('--lr_decay_epochs', type=int, default=25, help='multiply by a gamma every lr_decay_epoch epochs') 20 | parser.add_argument('--lr_gamma', type=float, default=0.9, help='gamma factor for lr_scheduler') 21 | self.isTrain = True 22 | return parser 23 | -------------------------------------------------------------------------------- /scripts/test_gal.sh: -------------------------------------------------------------------------------- 1 | python3 test.py --dataset pothole --model galnet --input tdisp --name tdisp_gal --gal --epoch best -------------------------------------------------------------------------------- /scripts/train_gal.sh: -------------------------------------------------------------------------------- 1 | python3 train.py --dataset pothole --model galnet --input tdisp --name tdisp_gal --gal -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | from options.test_options import TestOptions 3 | from data import CreateDataLoader 4 | from models import create_model 5 | from util.util import confusion_matrix, getScores, save_images 6 | import torch 7 | import numpy as np 8 | 9 | 10 | if __name__ == '__main__': 11 | opt = TestOptions().parse() 12 | opt.num_threads = 1 13 | opt.batch_size = 1 14 | opt.serial_batches = True # no shuffle 15 | opt.isTrain = False 16 | 17 | save_dir = os.path.join(opt.results_dir, opt.name, opt.phase + '_' + opt.epoch) 18 | if not os.path.exists(save_dir): 19 | os.makedirs(save_dir) 20 | 21 | data_loader = CreateDataLoader(opt) 22 | dataset = data_loader.load_data() 23 | model = create_model(opt, dataset.dataset) 24 | model.setup(opt) 25 | model.eval() 26 | 27 | test_loss_iter = [] 28 | epoch_iter = 0 29 | conf_mat = np.zeros((dataset.dataset.num_labels, dataset.dataset.num_labels), 
dtype=np.float32) 30 | with torch.no_grad(): 31 | for i, data in enumerate(dataset): 32 | model.set_input(data) 33 | model.forward() 34 | model.get_loss() 35 | epoch_iter += opt.batch_size 36 | gt = model.label.cpu().int().numpy() 37 | _, pred = torch.max(model.output.data.cpu(), 1) 38 | pred = pred.float().detach().int().numpy() 39 | save_images(save_dir, model.get_current_visuals(), model.get_image_names()) 40 | conf_mat += confusion_matrix(gt, pred, dataset.dataset.num_labels) 41 | 42 | test_loss_iter.append(model.loss_segmentation) 43 | print('Epoch {0:}, iters: {1:}/{2:}, loss: {3:.3f} '.format(opt.epoch, 44 | epoch_iter, 45 | len(dataset) * opt.batch_size, 46 | test_loss_iter[-1]), end='\r') 47 | 48 | avg_test_loss = torch.mean(torch.stack(test_loss_iter)) 49 | print ('Epoch {0:} test loss: {1:.3f} '.format(opt.epoch, avg_test_loss)) 50 | globalacc, pre, recall, F_score, iou = getScores(conf_mat) 51 | print ('Epoch {0:} glob acc : {1:.3f}, pre : {2:.3f}, recall : {3:.3f}, F_score : {4:.3f}, IoU : {5:.3f}'.format(opt.epoch, globalacc, pre, recall, F_score, iou)) 52 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import time 2 | from options.train_options import TrainOptions 3 | from data import CreateDataLoader 4 | from models import create_model 5 | from util.util import confusion_matrix, getScores, tensor2labelim, tensor2im, print_current_losses 6 | import numpy as np 7 | import random 8 | import torch 9 | import os 10 | from tensorboardX import SummaryWriter 11 | 12 | 13 | if __name__ == '__main__': 14 | train_opt = TrainOptions().parse() 15 | 16 | np.random.seed(train_opt.seed) 17 | random.seed(train_opt.seed) 18 | torch.manual_seed(train_opt.seed) 19 | torch.cuda.manual_seed(train_opt.seed) 20 | 21 | train_data_loader = CreateDataLoader(train_opt) 22 | train_dataset = train_data_loader.load_data() 23 | train_dataset_size = 
len(train_data_loader) 24 | print('#training images = %d' % train_dataset_size) 25 | 26 | valid_opt = TrainOptions().parse() 27 | valid_opt.phase = 'val' 28 | valid_opt.batch_size = 1 29 | valid_opt.num_threads = 1 30 | valid_opt.serial_batches = True 31 | valid_opt.isTrain = False 32 | valid_data_loader = CreateDataLoader(valid_opt) 33 | valid_dataset = valid_data_loader.load_data() 34 | valid_dataset_size = len(valid_data_loader) 35 | print('#validation images = %d' % valid_dataset_size) 36 | 37 | writer = SummaryWriter(os.path.join(train_opt.checkpoints_dir, train_opt.name)) 38 | 39 | model = create_model(train_opt, train_dataset.dataset) 40 | model.setup(train_opt) 41 | total_steps = 0 42 | tfcount = 0 43 | iou_max = 0 44 | for epoch in range(train_opt.epoch_count, train_opt.nepoch + 1): 45 | ### Training on the training set ### 46 | model.train() 47 | epoch_start_time = time.time() 48 | iter_data_time = time.time() 49 | epoch_iter = 0 50 | train_loss_iter = [] 51 | for i, data in enumerate(train_dataset): 52 | iter_start_time = time.time() 53 | if total_steps % train_opt.print_freq == 0: 54 | t_data = iter_start_time - iter_data_time 55 | total_steps += train_opt.batch_size 56 | epoch_iter += train_opt.batch_size 57 | model.set_input(data) 58 | model.optimize_parameters() 59 | 60 | if total_steps % train_opt.print_freq == 0: 61 | tfcount = tfcount + 1 62 | losses = model.get_current_losses() 63 | train_loss_iter.append(losses["segmentation"]) 64 | t = (time.time() - iter_start_time) / train_opt.batch_size 65 | print_current_losses(epoch, epoch_iter, losses, t, t_data) 66 | # There are several whole_loss values shown in tensorboard in one epoch, 67 | # to help better see the optimization phase 68 | writer.add_scalar('train/whole_loss', losses["segmentation"], tfcount) 69 | 70 | iter_data_time = time.time() 71 | 72 | mean_loss = np.mean(train_loss_iter) 73 | # One average training loss value in tensorboard in one epoch 74 | writer.add_scalar('train/mean_loss', 
mean_loss, epoch) 75 | 76 | palet_file = 'datasets/palette.txt' 77 | impalette = list(np.genfromtxt(palet_file,dtype=np.uint8).reshape(3*256)) 78 | tempDict = model.get_current_visuals() 79 | rgb = tensor2im(tempDict['rgb_image']) 80 | tdisp = tensor2im(tempDict['tdisp_image']) 81 | label = tensor2labelim(tempDict['label'], impalette) 82 | output = tensor2labelim(tempDict['output'], impalette) 83 | image_numpy = np.concatenate((rgb, tdisp, label, output), axis=1) 84 | image_numpy = image_numpy.astype(np.float32) / 255 85 | writer.add_image('Epoch' + str(epoch), image_numpy, dataformats='HWC') # show training images in tensorboard 86 | 87 | print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, train_opt.nepoch, time.time() - epoch_start_time)) 88 | model.update_learning_rate() 89 | 90 | ### Evaluation on the validation set ### 91 | model.eval() 92 | valid_loss_iter = [] 93 | epoch_iter = 0 94 | conf_mat = np.zeros((valid_dataset.dataset.num_labels, valid_dataset.dataset.num_labels), dtype=np.float32) 95 | with torch.no_grad(): 96 | for i, data in enumerate(valid_dataset): 97 | model.set_input(data) 98 | model.forward() 99 | model.get_loss() 100 | epoch_iter += valid_opt.batch_size 101 | gt = model.label.cpu().int().numpy() 102 | _, pred = torch.max(model.output.data.cpu(), 1) 103 | pred = pred.float().detach().int().numpy() 104 | 105 | conf_mat += confusion_matrix(gt, pred, valid_dataset.dataset.num_labels) 106 | losses = model.get_current_losses() 107 | valid_loss_iter.append(model.loss_segmentation) 108 | print('valid epoch {0:}, iters: {1:}/{2:} '.format(epoch, epoch_iter, len(valid_dataset) * valid_opt.batch_size), end='\r') 109 | 110 | avg_valid_loss = torch.mean(torch.stack(valid_loss_iter)) 111 | globalacc, pre, recall, F_score, iou = getScores(conf_mat) 112 | 113 | # Record performance on the validation set 114 | writer.add_scalar('valid/loss', avg_valid_loss, epoch) 115 | writer.add_scalar('valid/global_acc', globalacc, epoch) 116 | 
writer.add_scalar('valid/pre', pre, epoch) 117 | writer.add_scalar('valid/recall', recall, epoch) 118 | writer.add_scalar('valid/F_score', F_score, epoch) 119 | writer.add_scalar('valid/iou', iou, epoch) 120 | 121 | # Save the best model according to the F-score, and record corresponding epoch number in tensorboard 122 | if iou > iou_max: 123 | print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps)) 124 | model.save_networks('best') 125 | iou_max = iou 126 | writer.add_text('best model', str(epoch)) 127 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruirangerfan/GAL-DeepLabv3Plus/da613d0907ebf2908978a08f72b1a58e23caafb9/util/__init__.py -------------------------------------------------------------------------------- /util/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | import os 6 | import cv2 7 | 8 | 9 | def save_images(save_dir, visuals, image_name): 10 | """save images to disk""" 11 | image_name = image_name[0] 12 | palet_file = 'datasets/palette.txt' 13 | impalette = list(np.genfromtxt(palet_file, dtype=np.uint8).reshape(3*256)) 14 | 15 | for label, im_data in visuals.items(): 16 | if label == 'output': 17 | im = tensor2labelim(im_data, impalette) 18 | cv2.imwrite(os.path.join(save_dir, image_name), cv2.cvtColor(im, cv2.COLOR_RGB2BGR)) 19 | 20 | def tensor2im(input_image, imtype=np.uint8): 21 | """Converts a image Tensor into an image array (numpy)""" 22 | if isinstance(input_image, torch.Tensor): 23 | image_tensor = input_image.data 24 | else: 25 | return input_image 26 | image_numpy = image_tensor[0].cpu().float().numpy() 27 | if image_numpy.shape[0] == 1: 28 | image_numpy = np.tile(image_numpy, (3, 1, 1)) 29 | 
image_numpy = (np.transpose(image_numpy, (1, 2, 0)))* 255.0 30 | return image_numpy.astype(imtype) 31 | 32 | def tensor2labelim(label_tensor, impalette, imtype=np.uint8): 33 | """Converts a label Tensor into an image array (numpy), 34 | we use a palette to color the label images""" 35 | if len(label_tensor.shape) == 4: 36 | _, label_tensor = torch.max(label_tensor.data.cpu(), 1) 37 | 38 | label_numpy = label_tensor[0].cpu().float().detach().numpy() 39 | label_image = Image.fromarray(label_numpy.astype(np.uint8)) 40 | label_image = label_image.convert("P") 41 | label_image.putpalette(impalette) 42 | label_image = label_image.convert("RGB") 43 | return np.array(label_image).astype(imtype) 44 | 45 | def print_current_losses(epoch, i, losses, t, t_data): 46 | message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, i, t, t_data) 47 | for k, v in losses.items(): 48 | message += '%s: %.3f ' % (k, v) 49 | print(message) 50 | 51 | 52 | def mkdirs(paths): 53 | if isinstance(paths, list) and not isinstance(paths, str): 54 | for path in paths: 55 | mkdir(path) 56 | else: 57 | mkdir(paths) 58 | 59 | def mkdir(path): 60 | if not os.path.exists(path): 61 | os.makedirs(path) 62 | 63 | 64 | def confusion_matrix(x, y, n, ignore_label=None, mask=None): 65 | if mask is None: 66 | mask = np.ones_like(x) == 1 67 | k = (x >= 0) & (y < n) & (x != ignore_label) & (mask.astype(np.bool)) 68 | return np.bincount(n * x[k].astype(int) + y[k], minlength=n**2).reshape(n, n) 69 | 70 | def getScores(conf_matrix): 71 | if conf_matrix.sum() == 0: 72 | return 0, 0, 0, 0, 0 73 | with np.errstate(divide='ignore',invalid='ignore'): 74 | globalacc = np.diag(conf_matrix).sum() / conf_matrix.sum().astype(np.float32) 75 | classpre = np.diag(conf_matrix) / conf_matrix.sum(0).astype(np.float32) 76 | classrecall = np.diag(conf_matrix) / conf_matrix.sum(1).astype(np.float32) 77 | IU = np.diag(conf_matrix) / (conf_matrix.sum(1) + conf_matrix.sum(0) - np.diag(conf_matrix)).astype(np.float32) 78 | pre 
= classpre[1] 79 | recall = classrecall[1] 80 | iou = IU[1] 81 | F_score = 2*(recall*pre)/(recall+pre) 82 | return globalacc, pre, recall, F_score, iou 83 | --------------------------------------------------------------------------------