├── .gitignore ├── LICENSE ├── README.md ├── config ├── cityscapes_deeplabv3p.yaml ├── eval │ └── cityscapes_deeplab_v3_plus.yaml └── pascal_unet_res18_scse.yaml ├── data └── .gitkeep ├── logs └── .gitkeep ├── model └── .gitkeep ├── src ├── __init__.py ├── converter │ ├── convert_mobilenetv2.py │ └── convert_xception65.py ├── dataset │ ├── __init__.py │ ├── apolloscape.png │ ├── apolloscape.py │ ├── cityscapes.png │ ├── cityscapes.py │ ├── pascal_voc.png │ └── pascal_voc.py ├── eval.png ├── eval_cityscapes.py ├── logger │ ├── __init__.py │ ├── log.py │ └── plot.py ├── losses │ ├── binary │ │ ├── __init__.py │ │ ├── dice_loss.py │ │ ├── focal_loss.py │ │ └── lovasz_loss.py │ └── multi │ │ ├── __init__.py │ │ ├── focal_loss.py │ │ ├── lovasz_loss.py │ │ ├── ohem_loss.py │ │ ├── softiou_loss.py │ │ └── sym_loss.py ├── models │ ├── __init__.py │ ├── common.py │ ├── decoder.py │ ├── encoder.py │ ├── ibn.py │ ├── mobilenet.py │ ├── net.py │ ├── oc.py │ ├── scse.py │ ├── spp.py │ ├── tta.py │ └── xception.py ├── start_train.sh ├── stop_train.sh ├── train.py └── utils │ ├── __init__.py │ ├── custum_aug.py │ ├── functional.py │ ├── metrics.py │ ├── optimizer.py │ ├── preprocess.py │ ├── scheduler.py │ └── visualize.py └── tf_model └── .gitkeep /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest
32 | *.spec
33 | 
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 | 
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | 
49 | # Translations
50 | *.mo
51 | *.pot
52 | 
53 | # Django stuff:
54 | *.log
55 | .static_storage/
56 | .media/
57 | local_settings.py
58 | 
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 | 
63 | # Scrapy stuff:
64 | .scrapy
65 | 
66 | # Sphinx documentation
67 | docs/_build/
68 | 
69 | # PyBuilder
70 | target/
71 | 
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 | 
75 | # pyenv
76 | .python-version
77 | 
78 | # celery beat schedule file
79 | celerybeat-schedule
80 | 
81 | # SageMath parsed files
82 | *.sage.py
83 | 
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 | 
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 | 
97 | # Rope project settings
98 | .ropeproject
99 | 
100 | # mkdocs documentation
101 | /site
102 | 
103 | # mypy
104 | .mypy_cache/
105 | 
106 | # model directory
107 | model/*
108 | 
109 | # temporary directory
110 | tmp/
111 | 
112 | # data
113 | data/*
114 | 
115 | # checkpoint
116 | checkpoints/
117 | 
118 | # results
119 | output/*
120 | 
121 | # logs
122 | logs/*
123 | 
124 | # ref
125 | ref/
126 | 
127 | # notebooks
128 | notebooks
129 | 
130 | tf_model/*
131 | 
132 | notebook
133 | 
134 | .idea/
135 | 
136 | !.gitkeep
137 | 
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 
1 | MIT License
2 | 
3 | Copyright (c) 2019 Hiroki Taniai
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 
1 | # PytorchSegmentation
2 | This repository implements general networks for semantic segmentation.
3 | You can train various networks such as DeepLabV3+, PSPNet, and UNet just by writing a config file.
4 | 
5 | ![DeepLabV3+](src/eval.png)
6 | 
7 | ## Pretrained model
8 | You can run a pretrained model converted from the [official tensorflow model](https://github.com/tensorflow/models/blob/master/research/deeplab/g3doc/model_zoo.md).
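Once a checkpoint has been converted with one of the commands below, the resulting `model.pth` is an ordinary PyTorch state dict that can be loaded directly from the `src/` directory. A minimal sketch (the input image path is a placeholder; the constructor arguments mirror `config/eval/cityscapes_deeplab_v3_plus.yaml`, and `eval_cityscapes.py` does the same thing with batching and optional TTA):

```python
# Minimal sketch: load the converted DeepLabV3+ (Xception65 + ASPP) weights
# and predict one Cityscapes image. Run from the src/ directory.
import numpy as np
import torch
from PIL import Image

from models.net import SPPNet
from utils.preprocess import minmax_normalize

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Same Net parameters as config/eval/cityscapes_deeplab_v3_plus.yaml
model = SPPNet(output_channels=19, enc_type='xception65', dec_type='aspp', output_stride=8)
model.to(device)
model.update_bn_eps()
model.load_state_dict(torch.load('../model/cityscapes_deeplab_v3_plus/model.pth'))
model.eval()

img = np.array(Image.open('some_leftImg8bit.png'))  # placeholder image path
img = minmax_normalize(img, norm_range=(-1, 1))     # deeplab-style input scaling
img = torch.FloatTensor(img.transpose(2, 0, 1)).unsqueeze(0).to(device)

with torch.no_grad():
    pred = model.pred_resize(img, img.shape[2:], net_type='deeplab')
mask = pred.argmax(dim=1)[0].cpu().numpy().astype(np.uint8)  # train-id mask, shape (H, W)
```

The `19` passed to the converter commands below is this same number of output classes.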
9 | 
10 | ### DeepLabV3+(Xception65+ASPP)
11 | ```
12 | $ cd tf_model
13 | $ wget http://download.tensorflow.org/models/deeplabv3_cityscapes_train_2018_02_06.tar.gz
14 | $ tar -xvf deeplabv3_cityscapes_train_2018_02_06.tar.gz
15 | $ cd ../src
16 | $ python -m converter.convert_xception65 ../tf_model/deeplabv3_cityscapes_train/model.ckpt 19 ../model/cityscapes_deeplab_v3_plus/model.pth
17 | ```
18 | 
19 | Then you can test the performance of the trained network.
20 | 
21 | ```
22 | $ python eval_cityscapes.py --tta
23 | ```
24 | 
25 | mIoU on Cityscapes
26 | ```
27 | $ pip install cityscapesScripts
28 | $ export CITYSCAPES_RESULTS=../output/cityscapes_val/cityscapes_deeplab_v3_plus_tta
29 | $ export CITYSCAPES_DATASET=../data/cityscapes
30 | $ csEvalPixelLevelSemanticLabeling
31 | ```
32 | 
33 | ```
34 | classes IoU nIoU
35 | --------------------------------
36 | road : 0.984 nan
37 | sidewalk : 0.866 nan
38 | building : 0.931 nan
39 | wall : 0.626 nan
40 | fence : 0.635 nan
41 | pole : 0.668 nan
42 | traffic light : 0.698 nan
43 | traffic sign : 0.800 nan
44 | vegetation : 0.929 nan
45 | terrain : 0.651 nan
46 | sky : 0.954 nan
47 | person : 0.832 0.645
48 | rider : 0.644 0.452
49 | car : 0.956 0.887
50 | truck : 0.869 0.420
51 | bus : 0.906 0.657
52 | train : 0.834 0.555
53 | motorcycle : 0.674 0.404
54 | bicycle : 0.783 0.605
55 | --------------------------------
56 | Score Average : 0.802 0.578
57 | --------------------------------
58 | 
59 | 
60 | categories IoU nIoU
61 | --------------------------------
62 | flat : 0.988 nan
63 | construction : 0.937 nan
64 | object : 0.729 nan
65 | nature : 0.931 nan
66 | sky : 0.954 nan
67 | human : 0.842 0.667
68 | vehicle : 0.944 0.859
69 | --------------------------------
70 | Score Average : 0.904 0.763
71 | --------------------------------
72 | ```
73 | 
74 | ### MobilenetV2
75 | ```
76 | $ cd tf_model
77 | $ wget http://download.tensorflow.org/models/deeplabv3_mnv2_cityscapes_train_2018_02_05.tar.gz
78 | $ tar -xvf deeplabv3_mnv2_cityscapes_train_2018_02_05.tar.gz
79 | $ cd ../src
80 | $ python -m converter.convert_mobilenetv2 ../tf_model/deeplabv3_mnv2_cityscapes_train/model.ckpt 19 ../model/cityscapes_mobilnetv2/model.pth
81 | ```
82 | 
83 | 
84 | 
85 | ## How to train
86 | In order to train a model, you only have to set up a config file.
87 | For example, write a config file as below and save it as config/pascal_unet_res18_scse.yaml.
88 | 
89 | ```yaml
90 | Net:
91 |   enc_type: 'resnet18'
92 |   dec_type: 'unet_scse'
93 |   num_filters: 8
94 |   pretrained: True
95 | Data:
96 |   dataset: 'pascal'
97 |   target_size: (512, 512)
98 | Train:
99 |   max_epoch: 20
100 |   batch_size: 2
101 |   fp16: True
102 |   resume: False
103 |   pretrained_path:
104 | Loss:
105 |   loss_type: 'Lovasz'
106 |   ignore_index: 255
107 | Optimizer:
108 |   mode: 'adam'
109 |   base_lr: 0.001
110 |   t_max: 10
111 | ```
112 | 
113 | Then you can train this model by:
114 | 
115 | ```
116 | $ python train.py ../config/pascal_unet_res18_scse.yaml
117 | ```
118 | 
119 | ## Dataset
120 | - Cityscapes
121 | - Pascal VOC
122 |   - augmentation
123 |     - http://home.bharathh.info/pubs/codes/SBD/download.html
124 |     - https://github.com/TheLegendAli/DeepLab-Context/issues/10
125 | 
126 | ## Directory tree
127 | ```
128 | .
129 | ├── config 130 | ├── data 131 | │   ├── cityscapes 132 | │ │   ├── gtFine 133 | │ │   └── leftImg8bit 134 | │   └── pascal_voc_2012 135 | │ └── VOCdevkit 136 | │ └── VOC2012 137 | │ ├── JPEGImages 138 | │ ├── SegmentationClass 139 | │ └── SegmentationClassAug 140 | ├── logs 141 | ├── model 142 | └── src 143 | ├── dataset 144 | ├── logger 145 | ├── losses 146 | │   ├── binary 147 | │   └── multi 148 | ├── models 149 | └── utils 150 | ``` 151 | 152 | ## Environments 153 | - OS: Ubuntu18.04 154 | - python: 3.7.0 155 | - pytorch: 1.0.0 156 | - pretrainedmodels: 0.7.4 157 | - albumentations: 0.1.8 158 | 159 | if you want to train models in fp16 160 | - NVIDIA/apex: 0.1 161 | 162 | ## Reference 163 | 164 | ### Encoder 165 | - https://arxiv.org/abs/1505.04597 166 | - https://github.com/tugstugi/pytorch-saltnet 167 | 168 | ### Decoder 169 | #### SCSE 170 | - https://arxiv.org/abs/1803.02579 171 | 172 | #### IBN 173 | - https://arxiv.org/abs/1807.09441 174 | - https://github.com/XingangPan/IBN-Net 175 | - https://github.com/SeuTao/Kaggle_TGS2018_4th_solution 176 | 177 | #### OC 178 | - https://arxiv.org/abs/1809.00916 179 | - https://github.com/PkuRainBow/OCNet 180 | 181 | #### PSP 182 | - https://arxiv.org/abs/1612.01105 183 | 184 | #### ASPP 185 | - https://arxiv.org/abs/1802.02611 186 | -------------------------------------------------------------------------------- /config/cityscapes_deeplabv3p.yaml: -------------------------------------------------------------------------------- 1 | Net: 2 | enc_type: 'xception65' 3 | dec_type: 'aspp' 4 | output_stride: 8 5 | Data: 6 | dataset: 'cityscapes' 7 | target_size: (728, 728) 8 | Train: 9 | max_epoch: 60 10 | batch_size: 2 11 | fp16: True 12 | resume: False 13 | pretrained_path: 14 | Loss: 15 | loss_type: 'Lovasz' 16 | ignore_index: 255 17 | Optimizer: 18 | mode: 'sgd' 19 | base_lr: 0.007 20 | t_max: 30 21 | -------------------------------------------------------------------------------- /config/eval/cityscapes_deeplab_v3_plus.yaml: -------------------------------------------------------------------------------- 1 | # Config for evaluating cityscapes 2 | Net: 3 | enc_type: 'xception65' 4 | dec_type: 'aspp' 5 | output_stride: 8 6 | output_channels: 19 7 | -------------------------------------------------------------------------------- /config/pascal_unet_res18_scse.yaml: -------------------------------------------------------------------------------- 1 | Net: 2 | enc_type: 'resnet18' 3 | dec_type: 'unet_scse' 4 | num_filters: 8 5 | pretrained: True 6 | Data: 7 | dataset: 'pascal' 8 | target_size: (512, 512) 9 | Train: 10 | max_epoch: 20 11 | batch_size: 2 12 | fp16: True 13 | resume: False 14 | pretrained_path: 15 | Loss: 16 | loss_type: 'Lovasz' 17 | ignore_index: 255 18 | Optimizer: 19 | mode: 'adam' 20 | base_lr: 0.001 21 | t_max: 10 22 | -------------------------------------------------------------------------------- /data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nyoki-mtl/pytorch-segmentation/ddf7ce97114afa0f18902ba4a8f6dd6db581bcec/data/.gitkeep -------------------------------------------------------------------------------- /logs/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nyoki-mtl/pytorch-segmentation/ddf7ce97114afa0f18902ba4a8f6dd6db581bcec/logs/.gitkeep -------------------------------------------------------------------------------- /model/.gitkeep: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/nyoki-mtl/pytorch-segmentation/ddf7ce97114afa0f18902ba4a8f6dd6db581bcec/model/.gitkeep -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nyoki-mtl/pytorch-segmentation/ddf7ce97114afa0f18902ba4a8f6dd6db581bcec/src/__init__.py -------------------------------------------------------------------------------- /src/converter/convert_mobilenetv2.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | import tensorflow as tf 5 | import torch 6 | from models.net import SPPNet 7 | 8 | 9 | def convert_mobilenetv2(ckpt_path, num_classes): 10 | def conv_converter(pt_layer, tf_layer_name, depthwise=False, bias=False): 11 | if depthwise: 12 | pt_layer.weight.data = torch.Tensor( 13 | reader.get_tensor(f'{tf_layer_name}/depthwise_weights').transpose(2, 3, 0, 1)) 14 | else: 15 | pt_layer.weight.data = torch.Tensor(reader.get_tensor(f'{tf_layer_name}/weights').transpose(3, 2, 0, 1)) 16 | 17 | if bias: 18 | pt_layer.bias.data = torch.Tensor(reader.get_tensor(f'{tf_layer_name}/biases')) 19 | 20 | def bn_converter(pt_layer, tf_layer_name): 21 | pt_layer.bias.data = torch.Tensor(reader.get_tensor(f'{tf_layer_name}/beta')) 22 | pt_layer.weight.data = torch.Tensor(reader.get_tensor(f'{tf_layer_name}/gamma')) 23 | pt_layer.running_mean.data = torch.Tensor(reader.get_tensor(f'{tf_layer_name}/moving_mean')) 24 | pt_layer.running_var.data = torch.Tensor(reader.get_tensor(f'{tf_layer_name}/moving_variance')) 25 | 26 | def block_converter(pt_layer, tf_layer_name): 27 | if hasattr(pt_layer, 'expand'): 28 | conv_converter(pt_layer.expand.conv, f'{tf_layer_name}/expand') 29 | bn_converter(pt_layer.expand.bn, f'{tf_layer_name}/expand/BatchNorm') 30 | 31 | conv_converter(pt_layer.depthwise.conv, f'{tf_layer_name}/depthwise', depthwise=True) 32 | bn_converter(pt_layer.depthwise.bn, f'{tf_layer_name}/depthwise/BatchNorm') 33 | 34 | conv_converter(pt_layer.project.conv, f'{tf_layer_name}/project') 35 | bn_converter(pt_layer.project.bn, f'{tf_layer_name}/project/BatchNorm') 36 | 37 | reader = tf.train.NewCheckpointReader(ckpt_path) 38 | model = SPPNet(num_classes, enc_type='mobilenetv2', dec_type='maspp') 39 | 40 | # MobileNetV2 41 | conv_converter(model.encoder.conv, 'MobilenetV2/Conv') 42 | bn_converter(model.encoder.bn, 'MobilenetV2/Conv/BatchNorm') 43 | 44 | block_converter(model.encoder.block0, 'MobilenetV2/expanded_conv') 45 | block_converter(model.encoder.block1, 'MobilenetV2/expanded_conv_1') 46 | block_converter(model.encoder.block2, 'MobilenetV2/expanded_conv_2') 47 | block_converter(model.encoder.block3, 'MobilenetV2/expanded_conv_3') 48 | block_converter(model.encoder.block4, 'MobilenetV2/expanded_conv_4') 49 | block_converter(model.encoder.block5, 'MobilenetV2/expanded_conv_5') 50 | block_converter(model.encoder.block6, 'MobilenetV2/expanded_conv_6') 51 | block_converter(model.encoder.block7, 'MobilenetV2/expanded_conv_7') 52 | block_converter(model.encoder.block8, 'MobilenetV2/expanded_conv_8') 53 | block_converter(model.encoder.block9, 'MobilenetV2/expanded_conv_9') 54 | block_converter(model.encoder.block10, 'MobilenetV2/expanded_conv_10') 55 | block_converter(model.encoder.block11, 'MobilenetV2/expanded_conv_11') 56 | 
block_converter(model.encoder.block12, 'MobilenetV2/expanded_conv_12') 57 | block_converter(model.encoder.block13, 'MobilenetV2/expanded_conv_13') 58 | block_converter(model.encoder.block14, 'MobilenetV2/expanded_conv_14') 59 | block_converter(model.encoder.block15, 'MobilenetV2/expanded_conv_15') 60 | block_converter(model.encoder.block16, 'MobilenetV2/expanded_conv_16') 61 | 62 | # SPP 63 | conv_converter(model.spp.aspp0.conv, 'aspp0') 64 | bn_converter(model.spp.aspp0.bn, 'aspp0/BatchNorm') 65 | conv_converter(model.spp.image_pooling.conv, 'image_pooling') 66 | bn_converter(model.spp.image_pooling.bn, 'image_pooling/BatchNorm') 67 | conv_converter(model.spp.conv, 'concat_projection') 68 | bn_converter(model.spp.bn, 'concat_projection/BatchNorm') 69 | 70 | # Logits 71 | conv_converter(model.logits, 'logits/semantic', bias=True) 72 | 73 | return model 74 | 75 | 76 | if __name__ == '__main__': 77 | parser = argparse.ArgumentParser() 78 | parser.add_argument('ckpt_path') 79 | parser.add_argument('num_classes', type=int) 80 | parser.add_argument('output_path') 81 | args = parser.parse_args() 82 | 83 | ckpt_path = args.ckpt_path 84 | num_classes = args.num_classes 85 | output_path = Path(args.output_path) 86 | output_path.parent.mkdir() 87 | 88 | model = convert_mobilenetv2(ckpt_path, num_classes) 89 | torch.save(model.state_dict(), output_path) 90 | -------------------------------------------------------------------------------- /src/converter/convert_xception65.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | import tensorflow as tf 5 | import torch 6 | from models.net import SPPNet 7 | 8 | 9 | def convert_xception65(ckpt_path, num_classes): 10 | def conv_converter(pt_layer, tf_layer_name, depthwise=False, bias=False): 11 | if depthwise: 12 | pt_layer.weight.data = torch.Tensor( 13 | reader.get_tensor(f'{tf_layer_name}/depthwise_weights').transpose(2, 3, 0, 1)) 14 | else: 15 | pt_layer.weight.data = torch.Tensor(reader.get_tensor(f'{tf_layer_name}/weights').transpose(3, 2, 0, 1)) 16 | 17 | if bias: 18 | pt_layer.bias.data = torch.Tensor(reader.get_tensor(f'{tf_layer_name}/biases')) 19 | 20 | def bn_converter(pt_layer, tf_layer_name): 21 | pt_layer.bias.data = torch.Tensor(reader.get_tensor(f'{tf_layer_name}/beta')) 22 | pt_layer.weight.data = torch.Tensor(reader.get_tensor(f'{tf_layer_name}/gamma')) 23 | pt_layer.running_mean.data = torch.Tensor(reader.get_tensor(f'{tf_layer_name}/moving_mean')) 24 | pt_layer.running_var.data = torch.Tensor(reader.get_tensor(f'{tf_layer_name}/moving_variance')) 25 | 26 | def sepconv_converter(pt_layer, tf_layer_name): 27 | conv_converter(pt_layer.depthwise, f'{tf_layer_name}_depthwise', True) 28 | bn_converter(pt_layer.bn_depth, f'{tf_layer_name}_depthwise/BatchNorm') 29 | conv_converter(pt_layer.pointwise, f'{tf_layer_name}_pointwise') 30 | bn_converter(pt_layer.bn_point, f'{tf_layer_name}_pointwise/BatchNorm') 31 | 32 | def block_converter(pt_block, tf_block_name): 33 | if pt_block.skip_connection_type == 'conv': 34 | conv_converter(pt_block.conv, f'{tf_block_name}/shortcut') 35 | bn_converter(pt_block.bn, f'{tf_block_name}/shortcut/BatchNorm') 36 | 37 | sepconv_converter(pt_block.sep_conv1.block, f'{tf_block_name}/separable_conv1') 38 | sepconv_converter(pt_block.sep_conv2.block, f'{tf_block_name}/separable_conv2') 39 | sepconv_converter(pt_block.sep_conv3.block, f'{tf_block_name}/separable_conv3') 40 | 41 | reader = tf.train.NewCheckpointReader(ckpt_path) 42 | 
model = SPPNet(num_classes, enc_type='xception65', dec_type='aspp', output_stride=8) 43 | 44 | # Xception 45 | ## Entry flow 46 | conv_converter(model.encoder.conv1, 'xception_65/entry_flow/conv1_1') 47 | bn_converter(model.encoder.bn1, 'xception_65/entry_flow/conv1_1/BatchNorm') 48 | conv_converter(model.encoder.conv2, 'xception_65/entry_flow/conv1_2') 49 | bn_converter(model.encoder.bn2, 'xception_65/entry_flow/conv1_2/BatchNorm') 50 | block_converter(model.encoder.block1, 'xception_65/entry_flow/block1/unit_1/xception_module') 51 | block_converter(model.encoder.block2, 'xception_65/entry_flow/block2/unit_1/xception_module') 52 | block_converter(model.encoder.block3, 'xception_65/entry_flow/block3/unit_1/xception_module') 53 | ## Middle flow 54 | block_converter(model.encoder.block4, 'xception_65/middle_flow/block1/unit_1/xception_module') 55 | block_converter(model.encoder.block5, 'xception_65/middle_flow/block1/unit_2/xception_module') 56 | block_converter(model.encoder.block6, 'xception_65/middle_flow/block1/unit_3/xception_module') 57 | block_converter(model.encoder.block7, 'xception_65/middle_flow/block1/unit_4/xception_module') 58 | block_converter(model.encoder.block8, 'xception_65/middle_flow/block1/unit_5/xception_module') 59 | block_converter(model.encoder.block9, 'xception_65/middle_flow/block1/unit_6/xception_module') 60 | block_converter(model.encoder.block10, 'xception_65/middle_flow/block1/unit_7/xception_module') 61 | block_converter(model.encoder.block11, 'xception_65/middle_flow/block1/unit_8/xception_module') 62 | block_converter(model.encoder.block12, 'xception_65/middle_flow/block1/unit_9/xception_module') 63 | block_converter(model.encoder.block13, 'xception_65/middle_flow/block1/unit_10/xception_module') 64 | block_converter(model.encoder.block14, 'xception_65/middle_flow/block1/unit_11/xception_module') 65 | block_converter(model.encoder.block15, 'xception_65/middle_flow/block1/unit_12/xception_module') 66 | block_converter(model.encoder.block16, 'xception_65/middle_flow/block1/unit_13/xception_module') 67 | block_converter(model.encoder.block17, 'xception_65/middle_flow/block1/unit_14/xception_module') 68 | block_converter(model.encoder.block18, 'xception_65/middle_flow/block1/unit_15/xception_module') 69 | block_converter(model.encoder.block19, 'xception_65/middle_flow/block1/unit_16/xception_module') 70 | ## Exit flow 71 | block_converter(model.encoder.block20, 'xception_65/exit_flow/block1/unit_1/xception_module') 72 | block_converter(model.encoder.block21, 'xception_65/exit_flow/block2/unit_1/xception_module') 73 | 74 | # ASPP 75 | conv_converter(model.spp.aspp0.conv, 'aspp0') 76 | bn_converter(model.spp.aspp0.bn, 'aspp0/BatchNorm') 77 | sepconv_converter(model.spp.aspp1.block, 'aspp1') 78 | sepconv_converter(model.spp.aspp2.block, 'aspp2') 79 | sepconv_converter(model.spp.aspp3.block, 'aspp3') 80 | 81 | conv_converter(model.spp.image_pooling.conv, 'image_pooling') 82 | bn_converter(model.spp.image_pooling.bn, 'image_pooling/BatchNorm') 83 | conv_converter(model.spp.conv, 'concat_projection') 84 | bn_converter(model.spp.bn, 'concat_projection/BatchNorm') 85 | 86 | # Decoder 87 | conv_converter(model.decoder.conv, 'decoder/feature_projection0') 88 | bn_converter(model.decoder.bn, 'decoder/feature_projection0/BatchNorm') 89 | 90 | sepconv_converter(model.decoder.sep1.block, 'decoder/decoder_conv0') 91 | sepconv_converter(model.decoder.sep2.block, 'decoder/decoder_conv1') 92 | 93 | # Logits 94 | conv_converter(model.logits, 'logits/semantic', bias=True) 95 | 96 
| return model 97 | 98 | 99 | if __name__ == '__main__': 100 | parser = argparse.ArgumentParser() 101 | parser.add_argument('ckpt_path') 102 | parser.add_argument('num_classes', type=int) 103 | parser.add_argument('output_path') 104 | args = parser.parse_args() 105 | 106 | ckpt_path = args.ckpt_path 107 | num_classes = args.num_classes 108 | output_path = Path(args.output_path) 109 | output_path.parent.mkdir() 110 | 111 | model = convert_xception65(ckpt_path, num_classes) 112 | torch.save(model.state_dict(), output_path) 113 | -------------------------------------------------------------------------------- /src/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nyoki-mtl/pytorch-segmentation/ddf7ce97114afa0f18902ba4a8f6dd6db581bcec/src/dataset/__init__.py -------------------------------------------------------------------------------- /src/dataset/apolloscape.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nyoki-mtl/pytorch-segmentation/ddf7ce97114afa0f18902ba4a8f6dd6db581bcec/src/dataset/apolloscape.png -------------------------------------------------------------------------------- /src/dataset/apolloscape.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from pathlib import Path 4 | import albumentations as albu 5 | 6 | import torch 7 | from torch.utils.data import DataLoader, Dataset 8 | from torchvision import transforms 9 | 10 | # class definitions -> http://apolloscape.auto/scene.html 11 | n_classes = 36 12 | void_classes = [0, 1, 17, 34, 162, 35, 163, 37, 165, 38, 166, 39, 167, 40, 168, 50, 65, 66, 67, 81, 82, 83, 84, 85, 86, 97, 98, 99, 100, 113, 255] 13 | valid_classes = [33, 161, 36, 164, 49, 81] 14 | class_map = dict(zip(valid_classes, range(n_classes))) 15 | 16 | 17 | class ApolloscapeDataset(Dataset): 18 | def __init__(self, 19 | base_dir='../../data/apolloscape', 20 | road_record_list=[{'road':'road02_seg','record':[22, 23, 24, 25, 26]}, {'road':'road03_seg', 'record':[7, 8, 9, 10, 11, 12]}], 21 | split='train', 22 | ignore_index=255, 23 | debug=False): 24 | self.debug = debug 25 | self.base_dir = Path(base_dir) 26 | self.ignore_index = ignore_index 27 | self.split = split 28 | self.img_paths = [] 29 | self.lbl_paths = [] 30 | 31 | for road_record in road_record_list: 32 | self.road_dir = self.base_dir / Path(road_record['road']) 33 | self.record_list = road_record['record'] 34 | 35 | for record in self.record_list: 36 | img_paths_tmp = self.road_dir.glob(f'ColorImage/Record{record:03}/Camera 5/*.jpg') 37 | lbl_paths_tmp = self.road_dir.glob(f'Label/Record{record:03}/Camera 5/*.png') 38 | 39 | img_paths_basenames = {Path(img_path.name).stem for img_path in img_paths_tmp} 40 | lbl_paths_basenames = {Path(lbl_path.name).stem.replace('_bin', '') for lbl_path in lbl_paths_tmp} 41 | 42 | intersection_basenames = img_paths_basenames & lbl_paths_basenames 43 | 44 | img_paths_intersection = [self.road_dir / Path(f'ColorImage/Record{record:03}/Camera 5/{intersection_basename}.jpg') 45 | for intersection_basename in intersection_basenames] 46 | lbl_paths_intersection = [self.road_dir / Path(f'Label/Record{record:03}/Camera 5/{intersection_basename}_bin.png') 47 | for intersection_basename in intersection_basenames] 48 | 49 | self.img_paths += img_paths_intersection 50 | self.lbl_paths += lbl_paths_intersection 51 | 52 | self.img_paths.sort() 53 
| self.lbl_paths.sort() 54 | print(len(self.img_paths), len(self.lbl_paths)) 55 | assert len(self.img_paths) == len(self.lbl_paths) 56 | 57 | self.resizer = albu.Resize(height=512, width=1024) 58 | self.augmenter = albu.Compose([albu.HorizontalFlip(p=0.5), 59 | # albu.RandomRotate90(p=0.5), 60 | albu.Rotate(limit=10, p=0.5), 61 | # albu.CLAHE(p=0.2), 62 | # albu.RandomContrast(p=0.2), 63 | # albu.RandomBrightness(p=0.2), 64 | # albu.RandomGamma(p=0.2), 65 | # albu.GaussNoise(p=0.2), 66 | # albu.Cutout(p=0.2) 67 | ]) 68 | self.img_transformer = transforms.Compose([transforms.ToTensor(), 69 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 70 | std=[0.229, 0.224, 0.225])]) 71 | self.lbl_transformer = torch.LongTensor 72 | 73 | def __len__(self): 74 | return len(self.img_paths) 75 | 76 | def __getitem__(self, index): 77 | img_path = self.img_paths[index] 78 | lbl_path = self.lbl_paths[index] 79 | 80 | img = np.array(Image.open(img_path)) 81 | lbl = np.array(Image.open(lbl_path)) 82 | for c in void_classes: 83 | lbl[lbl == c] = self.ignore_index 84 | for c in valid_classes: 85 | lbl[lbl == c] = class_map[c] 86 | 87 | resized = self.resizer(image=img, mask=lbl) 88 | img, lbl = resized['image'], resized['mask'] 89 | 90 | if self.split == 'train': 91 | augmented = self.augmenter(image=img, mask=lbl) 92 | img, lbl = augmented['image'], augmented['mask'] 93 | 94 | if self.debug: 95 | print(np.unique(lbl)) 96 | else: 97 | img = self.img_transformer(img) 98 | lbl = self.lbl_transformer(lbl) 99 | 100 | return img, lbl, img_path.stem 101 | 102 | 103 | if __name__ == '__main__': 104 | import matplotlib 105 | matplotlib.use('Agg') 106 | import matplotlib.pyplot as plt 107 | 108 | dataset = ApolloscapeDataset(base_dir='../../data/apolloscape', debug=True) 109 | dataloader = DataLoader(dataset, batch_size=8, shuffle=True) 110 | print(len(dataset)) 111 | 112 | for i, batched in enumerate(dataloader): 113 | images, labels, _ = batched 114 | if i == 0: 115 | fig, axes = plt.subplots(8, 2, figsize=(10, 30)) 116 | plt.tight_layout() 117 | for j in range(8): 118 | axes[j][0].imshow(images[j]) 119 | axes[j][1].imshow(labels[j]) 120 | plt.savefig('apolloscape.png') 121 | plt.close() 122 | break 123 | -------------------------------------------------------------------------------- /src/dataset/cityscapes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nyoki-mtl/pytorch-segmentation/ddf7ce97114afa0f18902ba4a8f6dd6db581bcec/src/dataset/cityscapes.png -------------------------------------------------------------------------------- /src/dataset/cityscapes.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import numpy as np 3 | from PIL import Image 4 | from pathlib import Path 5 | import albumentations as albu 6 | 7 | import torch 8 | from torch.utils.data import DataLoader, Dataset 9 | 10 | from utils.preprocess import minmax_normalize, meanstd_normalize 11 | from utils.custum_aug import PadIfNeededRightBottom 12 | 13 | 14 | class CityscapesDataset(Dataset): 15 | n_classes = 19 16 | void_classes = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1] 17 | valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33] 18 | class_map = dict(zip(valid_classes, range(n_classes))) 19 | 20 | def __init__(self, base_dir='../data/cityscapes', split='train', 21 | affine_augmenter=None, image_augmenter=None, target_size=(1024, 2048), 22 | 
net_type='unet', ignore_index=255, debug=False): 23 | self.debug = debug 24 | self.base_dir = Path(base_dir) 25 | assert net_type in ['unet', 'deeplab'] 26 | self.net_type = net_type 27 | self.ignore_index = ignore_index 28 | self.split = 'val' if split == 'valid' else split 29 | 30 | self.img_paths = sorted(self.base_dir.glob(f'leftImg8bit/{self.split}/*/*leftImg8bit.png')) 31 | self.lbl_paths = sorted(self.base_dir.glob(f'gtFine/{self.split}/*/*gtFine_labelIds.png')) 32 | assert len(self.img_paths) == len(self.lbl_paths) 33 | 34 | # Resize 35 | if isinstance(target_size, str): 36 | target_size = eval(target_size) 37 | if self.split == 'train': 38 | if self.net_type == 'deeplab': 39 | target_size = (target_size[0] + 1, target_size[1] + 1) 40 | self.resizer = albu.Compose([albu.RandomScale(scale_limit=(-0.5, 0.5), p=1.0), 41 | PadIfNeededRightBottom(min_height=target_size[0], min_width=target_size[1], 42 | value=0, ignore_index=self.ignore_index, p=1.0), 43 | albu.RandomCrop(height=target_size[0], width=target_size[1], p=1.0)]) 44 | else: 45 | self.resizer = None 46 | 47 | # Augment 48 | if self.split == 'train': 49 | self.affine_augmenter = affine_augmenter 50 | self.image_augmenter = image_augmenter 51 | else: 52 | self.affine_augmenter = None 53 | self.image_augmenter = None 54 | 55 | def __len__(self): 56 | return len(self.img_paths) 57 | 58 | def __getitem__(self, index): 59 | img_path = self.img_paths[index] 60 | img = np.array(Image.open(img_path)) 61 | if self.split == 'test': 62 | # Resize (Scale & Pad & Crop) 63 | if self.net_type == 'unet': 64 | img = minmax_normalize(img) 65 | img = meanstd_normalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 66 | else: 67 | img = minmax_normalize(img, norm_range=(-1, 1)) 68 | if self.resizer: 69 | resized = self.resizer(image=img) 70 | img = resized['image'] 71 | img = img.transpose(2, 0, 1) 72 | img = torch.FloatTensor(img) 73 | return img 74 | else: 75 | lbl_path = self.lbl_paths[index] 76 | lbl = np.array(Image.open(lbl_path)) 77 | lbl = self.encode_mask(lbl) 78 | # ImageAugment (RandomBrightness, AddNoise...) 79 | if self.image_augmenter: 80 | augmented = self.image_augmenter(image=img) 81 | img = augmented['image'] 82 | # Resize (Scale & Pad & Crop) 83 | if self.net_type == 'unet': 84 | img = minmax_normalize(img) 85 | img = meanstd_normalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 86 | else: 87 | img = minmax_normalize(img, norm_range=(-1, 1)) 88 | if self.resizer: 89 | resized = self.resizer(image=img, mask=lbl) 90 | img, lbl = resized['image'], resized['mask'] 91 | # AffineAugment (Horizontal Flip, Rotate...) 
92 | if self.affine_augmenter: 93 | augmented = self.affine_augmenter(image=img, mask=lbl) 94 | img, lbl = augmented['image'], augmented['mask'] 95 | 96 | if self.debug: 97 | print(lbl_path) 98 | print(np.unique(lbl)) 99 | else: 100 | img = img.transpose(2, 0, 1) 101 | img = torch.FloatTensor(img) 102 | lbl = torch.LongTensor(lbl) 103 | return img, lbl, img_path.stem 104 | 105 | def encode_mask(self, lbl): 106 | for c in self.void_classes: 107 | lbl[lbl == c] = self.ignore_index 108 | for c in self.valid_classes: 109 | lbl[lbl == c] = self.class_map[c] 110 | return lbl 111 | 112 | 113 | if __name__ == '__main__': 114 | import matplotlib 115 | 116 | matplotlib.use('Agg') 117 | import matplotlib.pyplot as plt 118 | from utils.custum_aug import Rotate 119 | 120 | affine_augmenter = albu.Compose([albu.HorizontalFlip(p=.5), 121 | # Rotate(5, p=.5) 122 | ]) 123 | # image_augmenter = albu.Compose([albu.GaussNoise(p=.5), 124 | # albu.RandomBrightnessContrast(p=.5)]) 125 | image_augmenter = None 126 | dataset = CityscapesDataset(split='train', net_type='deeplab', ignore_index=19, debug=True, 127 | affine_augmenter=affine_augmenter, image_augmenter=image_augmenter) 128 | dataloader = DataLoader(dataset, batch_size=8, shuffle=True) 129 | print(len(dataset)) 130 | 131 | for i, batched in enumerate(dataloader): 132 | images, labels, _ = batched 133 | if i == 0: 134 | fig, axes = plt.subplots(8, 2, figsize=(20, 48)) 135 | plt.tight_layout() 136 | for j in range(8): 137 | axes[j][0].imshow(minmax_normalize(images[j], norm_range=(0, 1), orig_range=(-1, 1))) 138 | axes[j][1].imshow(labels[j]) 139 | axes[j][0].set_xticks([]) 140 | axes[j][0].set_yticks([]) 141 | axes[j][1].set_xticks([]) 142 | axes[j][1].set_yticks([]) 143 | plt.savefig('dataset/cityscapes.png') 144 | plt.close() 145 | break 146 | -------------------------------------------------------------------------------- /src/dataset/pascal_voc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nyoki-mtl/pytorch-segmentation/ddf7ce97114afa0f18902ba4a8f6dd6db581bcec/src/dataset/pascal_voc.png -------------------------------------------------------------------------------- /src/dataset/pascal_voc.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import numpy as np 3 | from PIL import Image 4 | from pathlib import Path 5 | import albumentations as albu 6 | 7 | import torch 8 | from torch.utils.data import DataLoader, Dataset 9 | 10 | from utils.preprocess import minmax_normalize, meanstd_normalize 11 | from utils.custum_aug import PadIfNeededRightBottom 12 | 13 | 14 | class PascalVocDataset(Dataset): 15 | n_classes = 21 16 | 17 | def __init__(self, base_dir='../data/pascal_voc_2012/VOCdevkit/VOC2012', split='train_aug', 18 | affine_augmenter=None, image_augmenter=None, target_size=(512, 512), 19 | net_type='unet', ignore_index=255, debug=False): 20 | self.debug = debug 21 | self.base_dir = Path(base_dir) 22 | assert net_type in ['unet', 'deeplab'] 23 | self.net_type = net_type 24 | self.ignore_index = ignore_index 25 | self.split = split 26 | 27 | valid_ids = self.base_dir / 'ImageSets' / 'Segmentation' / 'val.txt' 28 | with open(valid_ids, 'r') as f: 29 | valid_ids = f.readlines() 30 | if self.split == 'valid': 31 | lbl_dir = 'SegmentationClass' 32 | img_ids = valid_ids 33 | else: 34 | valid_set = set([valid_id.strip() for valid_id in valid_ids]) 35 | lbl_dir = 'SegmentationClassAug' if 'aug' in split else 
'SegmentationClass' 36 | all_set = set([p.name[:-4] for p in self.base_dir.joinpath(lbl_dir).iterdir()]) 37 | img_ids = list(all_set - valid_set) 38 | self.img_paths = [(self.base_dir / 'JPEGImages' / f'{img_id.strip()}.jpg') for img_id in img_ids] 39 | self.lbl_paths = [(self.base_dir / lbl_dir / f'{img_id.strip()}.png') for img_id in img_ids] 40 | 41 | # Resize 42 | if isinstance(target_size, str): 43 | target_size = eval(target_size) 44 | if 'train' in self.split: 45 | if self.net_type == 'deeplab': 46 | target_size = (target_size[0] + 1, target_size[1] + 1) 47 | self.resizer = albu.Compose([albu.RandomScale(scale_limit=(-0.5, 0.5), p=1.0), 48 | PadIfNeededRightBottom(min_height=target_size[0], min_width=target_size[1], 49 | value=0, ignore_index=self.ignore_index, p=1.0), 50 | albu.RandomCrop(height=target_size[0], width=target_size[1], p=1.0)]) 51 | else: 52 | # self.resizer = None 53 | self.resizer = albu.Compose([PadIfNeededRightBottom(min_height=target_size[0], min_width=target_size[1], 54 | value=0, ignore_index=self.ignore_index, p=1.0), 55 | albu.Crop(x_min=0, x_max=target_size[1], 56 | y_min=0, y_max=target_size[0])]) 57 | 58 | # Augment 59 | if 'train' in self.split: 60 | self.affine_augmenter = affine_augmenter 61 | self.image_augmenter = image_augmenter 62 | else: 63 | self.affine_augmenter = None 64 | self.image_augmenter = None 65 | 66 | def __len__(self): 67 | return len(self.img_paths) 68 | 69 | def __getitem__(self, index): 70 | img_path = self.img_paths[index] 71 | img = np.array(Image.open(img_path)) 72 | if self.split == 'test': 73 | # Resize (Scale & Pad & Crop) 74 | if self.net_type == 'unet': 75 | img = minmax_normalize(img) 76 | img = meanstd_normalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 77 | else: 78 | img = minmax_normalize(img, norm_range=(-1, 1)) 79 | if self.resizer: 80 | resized = self.resizer(image=img) 81 | img = resized['image'] 82 | img = img.transpose(2, 0, 1) 83 | img = torch.FloatTensor(img) 84 | return img 85 | else: 86 | lbl_path = self.lbl_paths[index] 87 | lbl = np.array(Image.open(lbl_path)) 88 | lbl[lbl == 255] = 0 89 | # ImageAugment (RandomBrightness, AddNoise...) 90 | if self.image_augmenter: 91 | augmented = self.image_augmenter(image=img) 92 | img = augmented['image'] 93 | # Resize (Scale & Pad & Crop) 94 | if self.net_type == 'unet': 95 | img = minmax_normalize(img) 96 | img = meanstd_normalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 97 | else: 98 | img = minmax_normalize(img, norm_range=(-1, 1)) 99 | if self.resizer: 100 | resized = self.resizer(image=img, mask=lbl) 101 | img, lbl = resized['image'], resized['mask'] 102 | # AffineAugment (Horizontal Flip, Rotate...) 
103 | if self.affine_augmenter: 104 | augmented = self.affine_augmenter(image=img, mask=lbl) 105 | img, lbl = augmented['image'], augmented['mask'] 106 | 107 | if self.debug: 108 | print(lbl_path) 109 | print(lbl.shape) 110 | print(np.unique(lbl)) 111 | else: 112 | img = img.transpose(2, 0, 1) 113 | img = torch.FloatTensor(img) 114 | lbl = torch.LongTensor(lbl) 115 | return img, lbl, img_path.stem 116 | 117 | 118 | if __name__ == '__main__': 119 | import matplotlib 120 | matplotlib.use('Agg') 121 | import matplotlib.pyplot as plt 122 | from utils.custum_aug import Rotate 123 | 124 | affine_augmenter = albu.Compose([albu.HorizontalFlip(p=.5), 125 | Rotate(5, p=.5) 126 | ]) 127 | # image_augmenter = albu.Compose([albu.GaussNoise(p=.5), 128 | # albu.RandomBrightnessContrast(p=.5)]) 129 | image_augmenter = None 130 | dataset = PascalVocDataset(affine_augmenter=affine_augmenter, image_augmenter=image_augmenter, split='valid', 131 | net_type='deeplab', ignore_index=21, target_size=(512, 512), debug=True) 132 | dataloader = DataLoader(dataset, batch_size=8, shuffle=True) 133 | print(len(dataset)) 134 | 135 | for i, batched in enumerate(dataloader): 136 | images, labels, _ = batched 137 | if i == 0: 138 | fig, axes = plt.subplots(8, 2, figsize=(20, 48)) 139 | plt.tight_layout() 140 | for j in range(8): 141 | axes[j][0].imshow(minmax_normalize(images[j], norm_range=(0, 1), orig_range=(-1, 1))) 142 | axes[j][1].imshow(labels[j]) 143 | axes[j][0].set_xticks([]) 144 | axes[j][0].set_yticks([]) 145 | axes[j][1].set_xticks([]) 146 | axes[j][1].set_yticks([]) 147 | plt.savefig('dataset/pascal_voc.png') 148 | plt.close() 149 | break 150 | -------------------------------------------------------------------------------- /src/eval.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nyoki-mtl/pytorch-segmentation/ddf7ce97114afa0f18902ba4a8f6dd6db581bcec/src/eval.png -------------------------------------------------------------------------------- /src/eval_cityscapes.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import yaml 3 | import numpy as np 4 | from pathlib import Path 5 | from PIL import Image 6 | from tqdm import tqdm 7 | 8 | import matplotlib 9 | matplotlib.use('Agg') 10 | import matplotlib.pyplot as plt 11 | 12 | import torch 13 | from torch.utils.data import DataLoader 14 | 15 | from models.net import EncoderDecoderNet, SPPNet 16 | from dataset.cityscapes import CityscapesDataset 17 | from utils.preprocess import minmax_normalize 18 | 19 | valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33] 20 | id2cls_dict = dict(zip(range(19), valid_classes)) 21 | id2cls_func = np.vectorize(id2cls_dict.get) 22 | 23 | def predict(batched, tta_flag=False): 24 | images, labels, names = batched 25 | images_np = images.numpy().transpose(0, 2, 3, 1) 26 | labels_np = labels.numpy() 27 | 28 | images, labels = images.to(device), labels.to(device) 29 | if tta_flag: 30 | preds = model.tta(images, scales=scales, net_type=net_type) 31 | else: 32 | preds = model.pred_resize(images, images.shape[2:], net_type=net_type) 33 | preds = preds.argmax(dim=1) 34 | preds_np = preds.detach().cpu().numpy().astype(np.uint8) 35 | return images_np, labels_np, preds_np, names 36 | 37 | parser = argparse.ArgumentParser() 38 | parser.add_argument('config_path') 39 | parser.add_argument('--tta', action='store_true') 40 | parser.add_argument('--vis', action='store_true') 41 | args 
= parser.parse_args() 42 | config_path = Path(args.config_path) 43 | tta_flag = args.tta 44 | vis_flag = args.vis 45 | 46 | config = yaml.load(open(config_path)) 47 | net_config = config['Net'] 48 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 49 | 50 | modelname = config_path.stem 51 | model_path = Path('../model') / modelname / 'model.pth' 52 | 53 | if 'unet' in net_config['dec_type']: 54 | net_type = 'unet' 55 | model = EncoderDecoderNet(**net_config) 56 | else: 57 | net_type = 'deeplab' 58 | model = SPPNet(**net_config) 59 | model.to(device) 60 | model.update_bn_eps() 61 | 62 | param = torch.load(model_path) 63 | model.load_state_dict(param) 64 | del param 65 | 66 | model.eval() 67 | 68 | batch_size = 1 69 | scales = [0.25, 0.75, 1, 1.25] 70 | valid_dataset = CityscapesDataset(split='valid', net_type=net_type) 71 | valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=True) 72 | 73 | if vis_flag: 74 | images_list = [] 75 | labels_list = [] 76 | preds_list = [] 77 | 78 | with torch.no_grad(): 79 | for batched in valid_loader: 80 | images_np, labels_np, preds_np, names = predict(batched) 81 | images_list.append(images_np) 82 | labels_list.append(labels_np) 83 | preds_list.append(preds_np) 84 | if len(images_list) == 4: 85 | break 86 | 87 | images = np.concatenate(images_list) 88 | labels = np.concatenate(labels_list) 89 | preds = np.concatenate(preds_list) 90 | 91 | ignore_pixel = labels == 255 92 | preds[ignore_pixel] = 20 93 | labels[ignore_pixel] = 20 94 | 95 | fig, axes = plt.subplots(4, 3, figsize=(12, 10)) 96 | plt.tight_layout() 97 | 98 | axes[0, 0].set_title('input image') 99 | axes[0, 1].set_title('prediction') 100 | axes[0, 2].set_title('ground truth') 101 | 102 | for ax, img, lbl, pred in zip(axes, images, labels, preds): 103 | ax[0].imshow(minmax_normalize(img, norm_range=(0, 1), orig_range=(-1, 1))) 104 | ax[1].imshow(pred) 105 | ax[2].imshow(lbl) 106 | ax[0].set_xticks([]) 107 | ax[0].set_yticks([]) 108 | ax[1].set_xticks([]) 109 | ax[1].set_yticks([]) 110 | ax[2].set_xticks([]) 111 | ax[2].set_yticks([]) 112 | 113 | plt.savefig('eval.png') 114 | plt.close() 115 | else: 116 | output_dir = Path('../output/cityscapes_val') / (str(modelname) + '_tta' if tta_flag else modelname) 117 | output_dir.mkdir(parents=True) 118 | 119 | with torch.no_grad(): 120 | for batched in tqdm(valid_loader): 121 | _, _, preds_np, names = predict(batched) 122 | preds_np = id2cls_func(preds_np).astype(np.uint8) 123 | for name, pred in zip(names, preds_np): 124 | Image.fromarray(pred).save(output_dir / f'{name}.png') 125 | -------------------------------------------------------------------------------- /src/logger/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nyoki-mtl/pytorch-segmentation/ddf7ce97114afa0f18902ba4a8f6dd6db581bcec/src/logger/__init__.py -------------------------------------------------------------------------------- /src/logger/log.py: -------------------------------------------------------------------------------- 1 | from logging import getLogger, StreamHandler, INFO, DEBUG, Formatter, FileHandler 2 | 3 | 4 | def debug_logger(log_dir): 5 | logger = getLogger('train') 6 | logger.setLevel(DEBUG) 7 | 8 | fmt = Formatter('%(asctime)s %(name)s %(lineno)d [%(levelname)s][%(funcName)s] %(message)s') 9 | 10 | sh = StreamHandler() 11 | sh.setLevel(INFO) 12 | sh.setFormatter(fmt) 13 | logger.addHandler(sh) 14 | 15 | fh = 
FileHandler(filename=log_dir.joinpath('debug.txt'), mode='w') 16 | fh.setLevel(DEBUG) 17 | fh.setFormatter(fmt) 18 | logger.addHandler(fh) 19 | return logger 20 | -------------------------------------------------------------------------------- /src/logger/plot.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib 3 | matplotlib.use('Agg') 4 | import matplotlib.pyplot as plt 5 | 6 | 7 | def history_ploter(history, path): 8 | history = np.asarray(history) 9 | title = path.name[:-4] 10 | fig = plt.figure() 11 | ax = fig.add_subplot(111) 12 | x = np.arange(len(history)) 13 | if history.ndim == 1: 14 | y = history 15 | ax.plot(x[y != None], y[y != None]) 16 | else: 17 | y = history[:, 0] 18 | ax.plot(x[y != None], y[y != None], label='train') 19 | y = history[:, 1] 20 | ax.plot(x[y != None], y[y != None], label='valid') 21 | ax.legend() 22 | ax.set_title(title) 23 | plt.savefig(str(path)) 24 | plt.close() 25 | -------------------------------------------------------------------------------- /src/losses/binary/__init__.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from .focal_loss import FocalLoss 4 | from .lovasz_loss import LovaszLoss 5 | from .dice_loss import DiceLoss, MixedDiceBCELoss 6 | 7 | 8 | class BinaryClassCriterion(nn.Module): 9 | def __init__(self, loss_type='BCE', **kwargs): 10 | super().__init__() 11 | if loss_type == 'BCE': 12 | self.criterion = nn.BCEWithLogitsLoss(**kwargs) 13 | elif loss_type == 'Focal': 14 | self.criterion = FocalLoss(**kwargs) 15 | elif loss_type == 'Lovasz': 16 | self.criterion = LovaszLoss(**kwargs) 17 | elif loss_type == 'Dice': 18 | self.criterion = DiceLoss(**kwargs) 19 | elif loss_type == 'MixedDiceBCE': 20 | self.criterion = MixedDiceBCELoss(**kwargs) 21 | else: 22 | raise NotImplementedError 23 | 24 | def forward(self, preds, labels): 25 | loss = self.criterion(preds, labels) 26 | return loss 27 | -------------------------------------------------------------------------------- /src/losses/binary/dice_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class DiceLoss(nn.Module): 6 | def __init__(self, smooth=0, eps=1e-7): 7 | super().__init__() 8 | self.smooth = smooth 9 | self.eps = eps 10 | 11 | def forward(self, preds, labels): 12 | return 1 - (2 * torch.sum(preds * labels) + self.smooth) / \ 13 | (torch.sum(preds) + torch.sum(labels) + self.smooth + self.eps) 14 | 15 | 16 | class MixedDiceBCELoss(nn.Module): 17 | def __init__(self, dice_weight=0.2, bce_weight=0.9): 18 | super().__init__() 19 | self.dice_loss = DiceLoss() 20 | self.bce_loss = nn.BCELoss() 21 | self.dice_weight = dice_weight 22 | self.bce_weight = bce_weight 23 | 24 | def forward(self, preds, labels): 25 | preds = torch.sigmoid(preds) 26 | loss = self.dice_weight * self.dice_loss(preds, labels) + self.bce_weight * self.bce_loss(preds, labels) 27 | return loss 28 | -------------------------------------------------------------------------------- /src/losses/binary/focal_loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | https://arxiv.org/abs/1708.02002 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class FocalLoss(nn.Module): 9 | def __init__(self, alpha=0.25, gamma=2, weight=None, ignore_index=255): 10 | super(FocalLoss, self).__init__() 11 | self.alpha = alpha 12 | self.gamma = 
gamma 13 | self.weight = weight 14 | self.ignore_index = ignore_index 15 | self.bce_fn = nn.BCEWithLogitsLoss(weight=self.weight) 16 | 17 | def forward(self, preds, labels): 18 | if self.ignore_index is not None: 19 | mask = labels != self.ignore_index 20 | labels = labels[mask] 21 | preds = preds[mask] 22 | 23 | logpt = -self.bce_fn(preds, labels) 24 | pt = torch.exp(logpt) 25 | loss = -((1 - pt) ** self.gamma) * self.alpha * logpt 26 | return loss 27 | -------------------------------------------------------------------------------- /src/losses/binary/lovasz_loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | Lovasz-Softmax and Jaccard hinge loss in PyTorch 3 | Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License) 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | def lovasz_grad(gt_sorted): 12 | """ 13 | Computes gradient of the Lovasz extension w.r.t sorted errors 14 | See Alg. 1 in paper 15 | """ 16 | p = len(gt_sorted) 17 | gts = gt_sorted.sum() 18 | intersection = gts - gt_sorted.float().cumsum(0) 19 | union = gts + (1 - gt_sorted).float().cumsum(0) 20 | jaccard = 1 - intersection / union 21 | if p > 1: # cover 1-pixel case 22 | jaccard[1:p] = jaccard[1:p] - jaccard[0:-1] 23 | return jaccard 24 | 25 | 26 | def hinge(pred, label): 27 | signs = 2 * label - 1 28 | errors = 1 - pred * signs 29 | return errors 30 | 31 | 32 | def lovasz_hinge_flat(logits, labels, ignore_index): 33 | """ 34 | Binary Lovasz hinge loss 35 | logits: [P] Variable, logits at each prediction (between -\infty and +\infty) 36 | labels: [P] Tensor, binary ground truth labels (0 or 1) 37 | ignore_index: label to ignore 38 | """ 39 | logits = logits.contiguous().view(-1) 40 | labels = labels.contiguous().view(-1) 41 | if ignore_index is not None: 42 | mask = labels != ignore_index 43 | logits = logits[mask] 44 | labels = labels[mask] 45 | errors = hinge(logits, labels) 46 | errors_sorted, perm = torch.sort(errors, dim=0, descending=True) 47 | perm = perm.data 48 | gt_sorted = labels[perm] 49 | grad = lovasz_grad(gt_sorted) 50 | loss = torch.dot(F.elu(errors_sorted) + 1, grad) 51 | return loss 52 | 53 | 54 | class LovaszLoss(nn.Module): 55 | """ 56 | Binary Lovasz hinge loss 57 | logits: [P] Variable, logits at each prediction (between -\infty and +\infty) 58 | labels: [P] Tensor, binary ground truth labels (0 or 1) 59 | ignore_index: label to ignore 60 | """ 61 | def __init__(self, ignore_index=None): 62 | super().__init__() 63 | self.ignore_index = ignore_index 64 | 65 | def forward(self, logits, labels): 66 | return lovasz_hinge_flat(logits, labels, self.ignore_index) 67 | -------------------------------------------------------------------------------- /src/losses/multi/__init__.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from .focal_loss import FocalLoss 4 | from .lovasz_loss import LovaszSoftmax 5 | from .ohem_loss import OhemCrossEntropy2d 6 | from .softiou_loss import SoftIoULoss 7 | 8 | 9 | class MultiClassCriterion(nn.Module): 10 | def __init__(self, loss_type='CrossEntropy', **kwargs): 11 | super().__init__() 12 | if loss_type == 'CrossEntropy': 13 | self.criterion = nn.CrossEntropyLoss(**kwargs) 14 | elif loss_type == 'Focal': 15 | self.criterion = FocalLoss(**kwargs) 16 | elif loss_type == 'Lovasz': 17 | self.criterion = LovaszSoftmax(**kwargs) 18 | elif loss_type == 'OhemCrossEntropy': 19 | self.criterion = 
OhemCrossEntropy2d(**kwargs) 20 | elif loss_type == 'SoftIOU': 21 | self.criterion = SoftIoULoss(**kwargs) 22 | else: 23 | raise NotImplementedError 24 | 25 | def forward(self, preds, labels): 26 | loss = self.criterion(preds, labels) 27 | return loss 28 | -------------------------------------------------------------------------------- /src/losses/multi/focal_loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | https://arxiv.org/abs/1708.02002 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class FocalLoss(nn.Module): 9 | def __init__(self, alpha=0.5, gamma=2, weight=None, ignore_index=255): 10 | super().__init__() 11 | self.alpha = alpha 12 | self.gamma = gamma 13 | self.weight = weight 14 | self.ignore_index = ignore_index 15 | self.ce_fn = nn.CrossEntropyLoss(weight=self.weight, ignore_index=self.ignore_index) 16 | 17 | def forward(self, preds, labels): 18 | logpt = -self.ce_fn(preds, labels) 19 | pt = torch.exp(logpt) 20 | loss = -((1 - pt) ** self.gamma) * self.alpha * logpt 21 | return loss 22 | -------------------------------------------------------------------------------- /src/losses/multi/lovasz_loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | Lovasz-Softmax and Jaccard hinge loss in PyTorch 3 | Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License) 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | def lovasz_grad(gt_sorted): 12 | """ 13 | Computes gradient of the Lovasz extension w.r.t sorted errors 14 | See Alg. 1 in paper 15 | """ 16 | p = len(gt_sorted) 17 | gts = gt_sorted.sum() 18 | intersection = gts - gt_sorted.float().cumsum(0) 19 | union = gts + (1 - gt_sorted).float().cumsum(0) 20 | jaccard = 1 - intersection / union 21 | if p > 1: # cover 1-pixel case 22 | jaccard[1:p] = jaccard[1:p] - jaccard[0:-1] 23 | return jaccard 24 | 25 | 26 | def lovasz_softmax_flat(prb, lbl, ignore_index, only_present): 27 | """ 28 | Multi-class Lovasz-Softmax loss 29 | prb: [P, C] Variable, class probabilities at each prediction (between 0 and 1) 30 | lbl: [P] Tensor, ground truth labels (between 0 and C - 1) 31 | ignore_index: void class labels 32 | only_present: average only on classes present in ground truth 33 | """ 34 | C = prb.shape[0] 35 | prb = prb.permute(1, 2, 0).contiguous().view(-1, C) # H * W, C 36 | lbl = lbl.view(-1) # H * W 37 | if ignore_index is not None: 38 | mask = lbl != ignore_index 39 | if mask.sum() == 0: 40 | return torch.mean(prb * 0) 41 | prb = prb[mask] 42 | lbl = lbl[mask] 43 | 44 | total_loss = 0 45 | cnt = 0 46 | for c in range(C): 47 | fg = (lbl == c).float() # foreground for class c 48 | if only_present and fg.sum() == 0: 49 | continue 50 | errors = (fg - prb[:, c]).abs() 51 | errors_sorted, perm = torch.sort(errors, dim=0, descending=True) 52 | perm = perm.data 53 | fg_sorted = fg[perm] 54 | total_loss += torch.dot(errors_sorted, lovasz_grad(fg_sorted)) 55 | cnt += 1 56 | return total_loss / cnt 57 | 58 | 59 | class LovaszSoftmax(nn.Module): 60 | """ 61 | Multi-class Lovasz-Softmax loss 62 | logits: [B, C, H, W] class logits at each prediction (between -\infty and \infty) 63 | labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1) 64 | ignore_index: void class labels 65 | only_present: average only on classes present in ground truth 66 | """ 67 | def __init__(self, ignore_index=None, only_present=True): 68 | super().__init__() 69 | self.ignore_index = ignore_index 70 | 
self.only_present = only_present 71 | 72 | def forward(self, logits, labels): 73 | probas = F.softmax(logits, dim=1) 74 | total_loss = 0 75 | batch_size = logits.shape[0] 76 | for prb, lbl in zip(probas, labels): 77 | total_loss += lovasz_softmax_flat(prb, lbl, self.ignore_index, self.only_present) 78 | return total_loss / batch_size 79 | -------------------------------------------------------------------------------- /src/losses/multi/ohem_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | 7 | # Adapted from OCNet Repository (https://github.com/PkuRainBow/OCNet) 8 | class OhemCrossEntropy2d(nn.Module): 9 | def __init__(self, thresh=0.6, min_kept=0, weight=None, ignore_index=255): 10 | super().__init__() 11 | self.ignore_label = ignore_index 12 | self.thresh = float(thresh) 13 | self.min_kept = int(min_kept) 14 | self.criterion = nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index) 15 | 16 | def forward(self, predict, target): 17 | """ 18 | Args: 19 | predict:(n, c, h, w) 20 | target:(n, h, w) 21 | """ 22 | 23 | n, c, h, w = predict.size() 24 | input_label = target.data.cpu().numpy().ravel().astype(np.int32) 25 | x = np.rollaxis(predict.data.cpu().numpy(), 1).reshape((c, -1)) 26 | input_prob = np.exp(x - x.max(axis=0).reshape((1, -1))) 27 | input_prob /= input_prob.sum(axis=0).reshape((1, -1)) 28 | 29 | valid_flag = input_label != self.ignore_label 30 | valid_inds = np.where(valid_flag)[0] 31 | label = input_label[valid_flag] 32 | num_valid = valid_flag.sum() 33 | if self.min_kept >= num_valid: 34 | print('Labels: {}'.format(num_valid)) 35 | elif num_valid > 0: 36 | prob = input_prob[:, valid_flag] 37 | pred = prob[label, np.arange(len(label), dtype=np.int32)] 38 | threshold = self.thresh 39 | if self.min_kept > 0: 40 | index = pred.argsort() 41 | threshold_index = index[min(len(index), self.min_kept) - 1] 42 | if pred[threshold_index] > self.thresh: 43 | threshold = pred[threshold_index] 44 | kept_flag = pred <= threshold 45 | valid_inds = valid_inds[kept_flag] 46 | print('hard ratio: {} = {} / {} '.format(round(len(valid_inds)/num_valid, 4), len(valid_inds), num_valid)) 47 | 48 | label = input_label[valid_inds].copy() 49 | input_label.fill(self.ignore_label) 50 | input_label[valid_inds] = label 51 | print(np.sum(input_label != self.ignore_label)) 52 | target = torch.from_numpy(input_label.reshape(target.size())).long().cuda() 53 | 54 | return self.criterion(predict, target) 55 | -------------------------------------------------------------------------------- /src/losses/multi/softiou_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class SoftIoULoss(nn.Module): 7 | def __init__(self, n_classes): 8 | super(SoftIoULoss, self).__init__() 9 | self.n_classes = n_classes 10 | 11 | @staticmethod 12 | def to_one_hot(tensor, n_classes): 13 | n, h, w = tensor.size() 14 | one_hot = torch.zeros(n, n_classes, h, w).scatter_(1, tensor.view(n, 1, h, w), 1) 15 | return one_hot 16 | 17 | def forward(self, logit, target): 18 | # logit => N x Classes x H x W 19 | # target => N x H x W 20 | 21 | N = len(logit) 22 | 23 | pred = F.softmax(logit, dim=1) 24 | target_onehot = self.to_one_hot(target, self.n_classes) 25 | 26 | # Numerator Product 27 | inter = pred * target_onehot 28 | # Sum over all pixels N x C x H x W => N x C 
29 | inter = inter.view(N, self.n_classes, -1).sum(2) 30 | 31 | # Denominator 32 | union = pred + target_onehot - (pred * target_onehot) 33 | # Sum over all pixels N x C x H x W => N x C 34 | union = union.view(N, self.n_classes, -1).sum(2) 35 | 36 | loss = inter / (union + 1e-16) 37 | 38 | # Return average loss over classes and batch 39 | return -loss.mean() 40 |
-------------------------------------------------------------------------------- /src/losses/multi/sym_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class SoftCrossEntropy(nn.Module): 7 | def __init__(self): 8 | super().__init__() 9 | 10 | def forward(self, logits, labels, valid_mask=None): 11 | if valid_mask is not None: 12 | loss = 0 13 | batch_size = logits.shape[0] 14 | for logit, lbl, val_msk in zip(logits, labels, valid_mask): 15 | logit = logit[:, val_msk] 16 | lbl = lbl[:, val_msk] 17 | loss -= torch.mean(torch.mul(F.log_softmax(logit, dim=0), F.softmax(lbl, dim=0))) 18 | return loss / batch_size 19 | else: 20 | return -torch.mean(torch.mul(F.log_softmax(logits, dim=1), F.softmax(labels, dim=1)))  # negated so both branches yield a positive cross-entropy to minimize 21 | 22 | 23 | class KlLoss(nn.Module): 24 | def __init__(self): 25 | super().__init__() 26 | 27 | def forward(self, logits, labels, valid_mask=None): 28 | if valid_mask is not None: 29 | loss = 0 30 | batch_size = logits.shape[0] 31 | for logit, lbl, val_msk in zip(logits, labels, valid_mask): 32 | logit = logit[:, val_msk] 33 | lbl = lbl[:, val_msk] 34 | loss += torch.mean(F.kl_div(F.log_softmax(logit, dim=0), F.softmax(lbl, dim=0), reduction='none')) 35 | return loss / batch_size 36 | else: 37 | return torch.mean(F.kl_div(F.log_softmax(logits, dim=1), F.softmax(labels, dim=1), reduction='none')) 38 |
-------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nyoki-mtl/pytorch-segmentation/ddf7ce97114afa0f18902ba4a8f6dd6db581bcec/src/models/__init__.py -------------------------------------------------------------------------------- /src/models/common.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import torch.nn as nn 3 | 4 | 5 | class _ActivatedBatchNorm(nn.Module): 6 | def __init__(self, num_features, activation='relu', slope=0.01, **kwargs): 7 | super().__init__() 8 | self.bn = nn.BatchNorm2d(num_features, **kwargs) 9 | if activation == 'relu': 10 | self.act = nn.ReLU(inplace=True) 11 | elif activation == 'leaky_relu': 12 | self.act = nn.LeakyReLU(negative_slope=slope, inplace=True) 13 | elif activation == 'elu': 14 | self.act = nn.ELU(inplace=True) 15 | else: 16 | self.act = None 17 | 18 | def forward(self, x): 19 | x = self.bn(x) 20 | if self.act: 21 | x = self.act(x) 22 | return x 23 | 24 | 25 | class SeparableConv2d(nn.Module): 26 | def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, relu_first=True): 27 | super().__init__() 28 | depthwise = nn.Conv2d(inplanes, inplanes, kernel_size, 29 | stride=stride, padding=dilation, 30 | dilation=dilation, groups=inplanes, bias=False) 31 | bn_depth = nn.BatchNorm2d(inplanes) 32 | pointwise = nn.Conv2d(inplanes, planes, 1, bias=False) 33 | bn_point = nn.BatchNorm2d(planes) 34 | 35 | if relu_first: 36 | self.block = nn.Sequential(OrderedDict([('relu', nn.ReLU()), 37 | ('depthwise', 
depthwise), 38 | ('bn_depth', bn_depth), 39 | ('pointwise', pointwise), 40 | ('bn_point', bn_point) 41 | ])) 42 | else: 43 | self.block = nn.Sequential(OrderedDict([('depthwise', depthwise), 44 | ('bn_depth', bn_depth), 45 | ('relu1', nn.ReLU()), 46 | ('pointwise', pointwise), 47 | ('bn_point', bn_point), 48 | ('relu2', nn.ReLU()) 49 | ])) 50 | 51 | def forward(self, x): 52 | return self.block(x) 53 | 54 | 55 | # import os 56 | # actbn_env = os.environ.get('INPLACE_ABN') 57 | # if actbn_env: 58 | # from .inplace_abn import InPlaceABN 59 | # ActivatedBatchNorm = InPlaceABN 60 | # else: 61 | # ActivatedBatchNorm = _ActivatedBatchNorm 62 | 63 | ActivatedBatchNorm = _ActivatedBatchNorm 64 | -------------------------------------------------------------------------------- /src/models/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .common import ActivatedBatchNorm, SeparableConv2d 5 | from .ibn import ImprovedIBNaDecoderBlock 6 | from .scse import SELayer, SCSEBlock 7 | from .oc import BaseOC 8 | 9 | 10 | class DecoderUnetSCSE(nn.Module): 11 | def __init__(self, in_channels, middle_channels, out_channels): 12 | super().__init__() 13 | self.block = nn.Sequential( 14 | nn.Conv2d(in_channels, middle_channels, kernel_size=3, padding=1), 15 | ActivatedBatchNorm(middle_channels), 16 | SCSEBlock(middle_channels), 17 | nn.ConvTranspose2d(middle_channels, out_channels, kernel_size=4, stride=2, padding=1) 18 | ) 19 | 20 | def forward(self, *args): 21 | x = torch.cat(args, 1) 22 | return self.block(x) 23 | 24 | 25 | class DecoderUnetSEIBN(nn.Module): 26 | def __init__(self, in_channels, middle_channels, out_channels): 27 | super().__init__() 28 | self.block = nn.Sequential( 29 | SELayer(in_channels), 30 | ImprovedIBNaDecoderBlock(in_channels, out_channels) 31 | ) 32 | 33 | def forward(self, *args): 34 | x = torch.cat(args, 1) 35 | return self.block(x) 36 | 37 | 38 | class DecoderUnetOC(nn.Module): 39 | def __init__(self, in_channels, middle_channels, out_channels): 40 | super().__init__() 41 | self.block = nn.Sequential( 42 | nn.Conv2d(in_channels, middle_channels, kernel_size=3, padding=1), 43 | ActivatedBatchNorm(middle_channels), 44 | BaseOC(in_channels=middle_channels, 45 | out_channels=middle_channels, 46 | dropout=0.2), 47 | nn.ConvTranspose2d(middle_channels, out_channels, kernel_size=4, stride=2, padding=1), 48 | ) 49 | 50 | def forward(self, *args): 51 | x = torch.cat(args, 1) 52 | return self.block(x) 53 | 54 | 55 | class DecoderSPP(nn.Module): 56 | def __init__(self): 57 | super().__init__() 58 | self.conv = nn.Conv2d(256, 48, 1, bias=False) 59 | self.bn = nn.BatchNorm2d(48) 60 | self.relu = nn.ReLU(inplace=True) 61 | self.sep1 = SeparableConv2d(304, 256, relu_first=False) 62 | self.sep2 = SeparableConv2d(256, 256, relu_first=False) 63 | 64 | def forward(self, x, low_level_feat): 65 | x = F.interpolate(x, size=low_level_feat.shape[2:], mode='bilinear', align_corners=True) 66 | low_level_feat = self.conv(low_level_feat) 67 | low_level_feat = self.bn(low_level_feat) 68 | low_level_feat = self.relu(low_level_feat) 69 | x = torch.cat((x, low_level_feat), dim=1) 70 | x = self.sep1(x) 71 | x = self.sep2(x) 72 | return x 73 | 74 | 75 | def create_decoder(dec_type): 76 | if dec_type == 'unet_scse': 77 | return DecoderUnetSCSE 78 | elif dec_type == 'unet_seibn': 79 | return DecoderUnetSEIBN 80 | elif dec_type == 'unet_oc': 81 | return DecoderUnetOC 82 | else: 83 | raise 
NotImplementedError 84 | -------------------------------------------------------------------------------- /src/models/encoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torchvision import models 3 | import pretrainedmodels 4 | from .xception import Xception65 5 | from .mobilenet import MobileNetV2 6 | 7 | 8 | def resnet(name, pretrained=False): 9 | def get_channels(layer): 10 | block = layer[-1] 11 | if isinstance(block, models.resnet.BasicBlock): 12 | return block.conv2.out_channels 13 | elif isinstance(block, models.resnet.Bottleneck): 14 | return block.conv3.out_channels 15 | raise RuntimeError("unknown resnet block: {}".format(block)) 16 | 17 | if name == 'resnet18': 18 | resnet = models.resnet18(pretrained=pretrained) 19 | elif name == 'resnet34': 20 | resnet = models.resnet34(pretrained=pretrained) 21 | elif name == 'resnet50': 22 | resnet = models.resnet50(pretrained=pretrained) 23 | elif name == 'resnet101': 24 | resnet = models.resnet101(pretrained=pretrained) 25 | elif name == 'resnet152': 26 | resnet = models.resnet152(pretrained=pretrained) 27 | else: 28 | return NotImplemented 29 | 30 | layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool) 31 | layer0.out_channels = resnet.bn1.num_features 32 | resnet.layer1.out_channels = get_channels(resnet.layer1) 33 | resnet.layer2.out_channels = get_channels(resnet.layer2) 34 | resnet.layer3.out_channels = get_channels(resnet.layer3) 35 | resnet.layer4.out_channels = get_channels(resnet.layer4) 36 | return [layer0, resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4] 37 | 38 | 39 | def resnext(name, pretrained=False): 40 | if name in ['resnext101_32x4d', 'resnext101_64x4d']: 41 | pretrained = 'imagenet' if pretrained else None 42 | resnext = pretrainedmodels.__dict__[name](num_classes=1000, pretrained=pretrained) 43 | else: 44 | return NotImplemented 45 | 46 | layer0 = nn.Sequential(resnext.features[0], 47 | resnext.features[1], 48 | resnext.features[2], 49 | resnext.features[3]) 50 | layer1 = resnext.features[4] 51 | layer2 = resnext.features[5] 52 | layer3 = resnext.features[6] 53 | layer4 = resnext.features[7] 54 | 55 | layer0.out_channels = 64 56 | layer1.out_channels = 256 57 | layer2.out_channels = 512 58 | layer3.out_channels = 1024 59 | layer4.out_channels = 2048 60 | return [layer0, layer1, layer2, layer3, layer4] 61 | 62 | 63 | def se_net(name, pretrained=False): 64 | if name in ['se_resnet50', 'se_resnet101', 'se_resnet152', 65 | 'se_resnext50_32x4d', 'se_resnext101_32x4d', 'senet154']: 66 | pretrained = 'imagenet' if pretrained else None 67 | senet = pretrainedmodels.__dict__[name](num_classes=1000, pretrained=pretrained) 68 | else: 69 | return NotImplemented 70 | 71 | layer0 = senet.layer0 72 | layer1 = senet.layer1 73 | layer2 = senet.layer2 74 | layer3 = senet.layer3 75 | layer4 = senet.layer4 76 | 77 | layer0.out_channels = senet.layer1[0].conv1.in_channels 78 | layer1.out_channels = senet.layer1[-1].conv3.out_channels 79 | layer2.out_channels = senet.layer2[-1].conv3.out_channels 80 | layer3.out_channels = senet.layer3[-1].conv3.out_channels 81 | layer4.out_channels = senet.layer4[-1].conv3.out_channels 82 | 83 | return [layer0, layer1, layer2, layer3, layer4] 84 | 85 | 86 | def create_encoder(enc_type, output_stride=8, pretrained=True): 87 | if enc_type.startswith('resnet'): 88 | return resnet(enc_type, pretrained) 89 | elif enc_type.startswith('resnext'): 90 | return resnext(enc_type, pretrained) 91 | elif 
enc_type.startswith('se'): 92 | return se_net(enc_type, pretrained) 93 | elif enc_type == 'xception65': 94 | return Xception65(output_stride) 95 | elif enc_type == 'mobilenetv2': 96 | return MobileNetV2(pretrained) 97 | else: 98 | raise NotImplementedError 99 | -------------------------------------------------------------------------------- /src/models/ibn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .common import ActivatedBatchNorm 4 | 5 | 6 | class IBN(nn.Module): 7 | def __init__(self, planes): 8 | super().__init__() 9 | half1 = int(planes / 2) 10 | self.half = half1 11 | half2 = planes - half1 12 | self.IN = nn.Sequential(nn.InstanceNorm2d(half1, affine=True), 13 | nn.ReLU(inplace=True)) 14 | self.BN = ActivatedBatchNorm(half2) 15 | 16 | def forward(self, x): 17 | split = torch.split(x, self.half, 1) 18 | out1 = self.IN(split[0].contiguous()) 19 | out2 = self.BN(split[1].contiguous()) 20 | out = torch.cat((out1, out2), 1) 21 | return out 22 | 23 | 24 | class ImprovedIBNaDecoderBlock(nn.Module): 25 | def __init__(self, in_channels, out_channels): 26 | super().__init__() 27 | 28 | self.block = nn.Sequential( 29 | nn.Conv2d(in_channels, in_channels // 4, 1), 30 | IBN(in_channels // 4), 31 | nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 4, stride=2, padding=1), 32 | ActivatedBatchNorm(in_channels // 4), 33 | nn.Conv2d(in_channels // 4, out_channels, 1), 34 | ActivatedBatchNorm(out_channels) 35 | ) 36 | 37 | def forward(self, x): 38 | return self.block(x) 39 | -------------------------------------------------------------------------------- /src/models/mobilenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | from collections import OrderedDict 5 | 6 | 7 | class ExpandedConv(nn.Module): 8 | def __init__(self, in_channels, out_channels, stride=1, dilation=1, 9 | expand_ratio=6, skip_connection=False): 10 | super().__init__() 11 | 12 | self.stride = stride 13 | self.kernel_size = 3 14 | self.dilation = dilation 15 | self.expand_ratio = expand_ratio 16 | self.skip_connection = skip_connection 17 | middle_channels = in_channels * expand_ratio 18 | 19 | if self.expand_ratio != 1: 20 | # pointwise 21 | self.expand = nn.Sequential(OrderedDict( 22 | [('conv', nn.Conv2d(in_channels, middle_channels, 1, bias=False)), 23 | ('bn', nn.BatchNorm2d(middle_channels)), 24 | ('relu', nn.ReLU6(inplace=True)) 25 | ])) 26 | 27 | # depthwise 28 | self.depthwise = nn.Sequential(OrderedDict( 29 | [('conv', nn.Conv2d(middle_channels, middle_channels, 3, stride, dilation, dilation, groups=middle_channels, bias=False)), 30 | ('bn', nn.BatchNorm2d(middle_channels)), 31 | ('relu', nn.ReLU6(inplace=True)) 32 | ])) 33 | 34 | # project 35 | self.project = nn.Sequential(OrderedDict( 36 | [('conv', nn.Conv2d(middle_channels, out_channels, 1, bias=False)), 37 | ('bn', nn.BatchNorm2d(out_channels)) 38 | ])) 39 | 40 | def forward(self, x): 41 | if self.expand_ratio != 1: 42 | residual = self.project(self.depthwise(self.expand(x))) 43 | else: 44 | residual = self.project(self.depthwise(x)) 45 | 46 | if self.skip_connection: 47 | outputs = x + residual 48 | else: 49 | outputs = residual 50 | return outputs 51 | 52 | 53 | class MobileNetV2(nn.Module): 54 | def __init__(self, pretrained=False, model_path='../model/mobilenetv2_encoder/model.pth'): 55 | super().__init__() 56 | 57 | self.conv = nn.Conv2d(3, 32, 3, 
stride=2, padding=1, bias=False) 58 | self.bn = nn.BatchNorm2d(32) 59 | self.relu = nn.ReLU6() 60 | 61 | self.block0 = ExpandedConv(32, 16, expand_ratio=1) 62 | self.block1 = ExpandedConv(16, 24, stride=2) 63 | self.block2 = ExpandedConv(24, 24, skip_connection=True) 64 | self.block3 = ExpandedConv(24, 32, stride=2) 65 | self.block4 = ExpandedConv(32, 32, skip_connection=True) 66 | self.block5 = ExpandedConv(32, 32, skip_connection=True) 67 | self.block6 = ExpandedConv(32, 64) 68 | self.block7 = ExpandedConv(64, 64, dilation=2, skip_connection=True) 69 | self.block8 = ExpandedConv(64, 64, dilation=2, skip_connection=True) 70 | self.block9 = ExpandedConv(64, 64, dilation=2, skip_connection=True) 71 | self.block10 = ExpandedConv(64, 96, dilation=2) 72 | self.block11 = ExpandedConv(96, 96, dilation=2, skip_connection=True) 73 | self.block12 = ExpandedConv(96, 96, dilation=2, skip_connection=True) 74 | self.block13 = ExpandedConv(96, 160, dilation=2) 75 | self.block14 = ExpandedConv(160, 160, dilation=4, skip_connection=True) 76 | self.block15 = ExpandedConv(160, 160, dilation=4, skip_connection=True) 77 | self.block16 = ExpandedConv(160, 320, dilation=4) 78 | 79 | if pretrained: 80 | self.load_pretrained_model(model_path) 81 | 82 | def forward(self, x): 83 | x = self.conv(x) 84 | x = self.bn(x) 85 | x = self.relu(x) 86 | x = self.block0(x) 87 | x = self.block1(x) 88 | x = self.block2(x) 89 | x = self.block3(x) 90 | x = self.block4(x) 91 | x = self.block5(x) 92 | x = self.block6(x) 93 | x = self.block7(x) 94 | x = self.block8(x) 95 | x = self.block9(x) 96 | x = self.block10(x) 97 | x = self.block11(x) 98 | x = self.block12(x) 99 | x = self.block13(x) 100 | x = self.block14(x) 101 | x = self.block15(x) 102 | x = self.block16(x) 103 | return x 104 | 105 | def load_pretrained_model(self, model_path): 106 | self.load_state_dict(torch.load(model_path)) 107 | print(f'Load from {model_path}!') 108 | -------------------------------------------------------------------------------- /src/models/net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .common import ActivatedBatchNorm 6 | from .encoder import create_encoder 7 | from .decoder import create_decoder 8 | from .spp import create_spp, create_mspp 9 | from .tta import SegmentatorTTA 10 | 11 | 12 | class EncoderDecoderNet(nn.Module, SegmentatorTTA): 13 | def __init__(self, output_channels=19, enc_type='resnet50', dec_type='unet_scse', 14 | num_filters=16, pretrained=False): 15 | super().__init__() 16 | self.output_channels = output_channels 17 | self.enc_type = enc_type 18 | self.dec_type = dec_type 19 | 20 | assert enc_type in ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 21 | 'resnext101_32x4d', 'resnext101_64x4d', 22 | 'se_resnet50', 'se_resnet101', 'se_resnet152', 23 | 'se_resnext50_32x4d', 'se_resnext101_32x4d', 'senet154'] 24 | assert dec_type in ['unet_scse', 'unet_seibn', 'unet_oc'] 25 | 26 | encoder = create_encoder(enc_type, pretrained) 27 | Decoder = create_decoder(dec_type) 28 | 29 | self.encoder1 = encoder[0] 30 | self.encoder2 = encoder[1] 31 | self.encoder3 = encoder[2] 32 | self.encoder4 = encoder[3] 33 | self.encoder5 = encoder[4] 34 | 35 | self.pool = nn.MaxPool2d(2, 2) 36 | self.center = Decoder(self.encoder5.out_channels, num_filters * 32 * 2, num_filters * 32) 37 | 38 | self.decoder5 = Decoder(self.encoder5.out_channels + num_filters * 32, num_filters * 32 * 2, 39 | num_filters * 16) 40 
| self.decoder4 = Decoder(self.encoder4.out_channels + num_filters * 16, num_filters * 16 * 2, 41 | num_filters * 8) 42 | self.decoder3 = Decoder(self.encoder3.out_channels + num_filters * 8, num_filters * 8 * 2, num_filters * 4) 43 | self.decoder2 = Decoder(self.encoder2.out_channels + num_filters * 4, num_filters * 4 * 2, num_filters * 2) 44 | self.decoder1 = Decoder(self.encoder1.out_channels + num_filters * 2, num_filters * 2 * 2, num_filters) 45 | 46 | self.logits = nn.Sequential( 47 | nn.Conv2d(num_filters * (16 + 8 + 4 + 2 + 1), 64, kernel_size=1, padding=0), 48 | ActivatedBatchNorm(64), 49 | nn.Conv2d(64, self.output_channels, kernel_size=1) 50 | ) 51 | 52 | def forward(self, x): 53 | img_size = x.shape[2:] 54 | 55 | e1 = self.encoder1(x) 56 | e2 = self.encoder2(e1) 57 | e3 = self.encoder3(e2) 58 | e4 = self.encoder4(e3) 59 | e5 = self.encoder5(e4) 60 | 61 | c = self.center(self.pool(e5)) 62 | e1_up = F.interpolate(e1, scale_factor=2, mode='bilinear', align_corners=False) 63 | 64 | d5 = self.decoder5(c, e5) 65 | d4 = self.decoder4(d5, e4) 66 | d3 = self.decoder3(d4, e3) 67 | d2 = self.decoder2(d3, e2) 68 | d1 = self.decoder1(d2, e1_up) 69 | 70 | u5 = F.interpolate(d5, img_size, mode='bilinear', align_corners=False) 71 | u4 = F.interpolate(d4, img_size, mode='bilinear', align_corners=False) 72 | u3 = F.interpolate(d3, img_size, mode='bilinear', align_corners=False) 73 | u2 = F.interpolate(d2, img_size, mode='bilinear', align_corners=False) 74 | 75 | # Hyper column 76 | d = torch.cat((d1, u2, u3, u4, u5), 1) 77 | logits = self.logits(d) 78 | 79 | return logits 80 | 81 | 82 | class SPPNet(nn.Module, SegmentatorTTA): 83 | def __init__(self, output_channels=19, enc_type='xception65', dec_type='aspp', output_stride=8): 84 | super().__init__() 85 | self.output_channels = output_channels 86 | self.enc_type = enc_type 87 | self.dec_type = dec_type 88 | 89 | assert enc_type in ['xception65', 'mobilenetv2'] 90 | assert dec_type in ['oc_base', 'oc_asp', 'spp', 'aspp', 'maspp'] 91 | 92 | self.encoder = create_encoder(enc_type, output_stride=output_stride, pretrained=False) 93 | if enc_type == 'mobilenetv2': 94 | self.spp = create_mspp(dec_type) 95 | else: 96 | self.spp, self.decoder = create_spp(dec_type, output_stride=output_stride) 97 | self.logits = nn.Conv2d(256, output_channels, 1) 98 | 99 | def forward(self, inputs): 100 | if self.enc_type == 'mobilenetv2': 101 | x = self.encoder(inputs) 102 | x = self.spp(x) 103 | x = self.logits(x) 104 | return x 105 | else: 106 | x, low_level_feat = self.encoder(inputs) 107 | x = self.spp(x) 108 | x = self.decoder(x, low_level_feat) 109 | x = self.logits(x) 110 | return x 111 | 112 | def update_bn_eps(self): 113 | for m in self.encoder.named_modules(): 114 | if isinstance(m[1], nn.BatchNorm2d): 115 | m[1].eps = 1e-3 116 | 117 | def freeze_bn(self): 118 | for m in self.modules(): 119 | if isinstance(m, nn.modules.batchnorm._BatchNorm): 120 | m.eval() 121 | # for p in m.parameters(): 122 | # p.requires_grad = False 123 | 124 | def get_1x_lr_params(self): 125 | for p in self.encoder.parameters(): 126 | yield p 127 | 128 | def get_10x_lr_params(self): 129 | modules = [self.spp, self.logits] 130 | if hasattr(self, 'decoder'): 131 | modules.append(self.decoder) 132 | 133 | for module in modules: 134 | for p in module.parameters(): 135 | yield p 136 | -------------------------------------------------------------------------------- /src/models/oc.py: -------------------------------------------------------------------------------- 1 | """ 2 | OCNet: Object 
Context Network for Scene Parsing 3 | https://github.com/PkuRainBow/OCNet 4 | """ 5 | 6 | import torch 7 | from torch import nn 8 | from torch.nn import functional as F 9 | from .common import ActivatedBatchNorm 10 | 11 | 12 | class SelfAttentionBlock2D(nn.Module): 13 | """ 14 | The basic implementation for self-attention block/non-local block 15 | Input: 16 | N X C X H X W 17 | Parameters: 18 | in_channels : the dimension of the input feature map 19 | key_channels : the dimension after the key/query transform 20 | value_channels : the dimension after the value transform 21 | scale : choose the scale to downsample the input feature maps (save memory cost) 22 | Return: 23 | N X C X H X W 24 | position-aware context features.(w/o concate or add with the input) 25 | """ 26 | 27 | def __init__(self, in_channels, key_channels, value_channels, out_channels=None, scale=1): 28 | super().__init__() 29 | self.scale = scale 30 | self.in_channels = in_channels 31 | self.out_channels = out_channels 32 | self.key_channels = key_channels 33 | self.value_channels = value_channels 34 | if out_channels is None: 35 | self.out_channels = in_channels 36 | self.pool = nn.MaxPool2d(kernel_size=(scale, scale)) 37 | self.f_key = nn.Sequential( 38 | nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels, kernel_size=1), 39 | ActivatedBatchNorm(self.key_channels) 40 | ) 41 | self.f_query = self.f_key 42 | self.f_value = nn.Conv2d(in_channels=self.in_channels, out_channels=self.value_channels, kernel_size=1) 43 | self.W = nn.Conv2d(in_channels=self.value_channels, out_channels=self.out_channels, kernel_size=1) 44 | nn.init.constant_(self.W.weight, 0) 45 | nn.init.constant_(self.W.bias, 0) 46 | 47 | def forward(self, x): 48 | batch_size, h, w = x.size(0), x.size(2), x.size(3) 49 | if self.scale > 1: 50 | x = self.pool(x) 51 | 52 | value = self.f_value(x).view(batch_size, self.value_channels, -1) 53 | value = value.permute(0, 2, 1) # b, h*w, v 54 | query = self.f_query(x).view(batch_size, self.key_channels, -1) 55 | query = query.permute(0, 2, 1) # b, h*w, k 56 | key = self.f_key(x).view(batch_size, self.key_channels, -1) # b, k, h*w 57 | 58 | sim_map = torch.matmul(query, key) 59 | sim_map = (self.key_channels ** -.5) * sim_map 60 | sim_map = F.softmax(sim_map, dim=-1) # b * h*w * h*w 61 | 62 | context = torch.matmul(sim_map, value) # b * h*w * v 63 | context = context.permute(0, 2, 1).contiguous() 64 | context = context.view(batch_size, self.value_channels, *x.size()[2:]) 65 | context = self.W(context) 66 | if self.scale > 1: 67 | context = F.interpolate(context, size=(h, w), mode='bilinear', align_corners=True) 68 | return context 69 | 70 | 71 | class BaseOC_Context(nn.Module): 72 | """ 73 | Output only the context features. 74 | Parameters: 75 | in_features / out_features: the channels of the input / output feature maps. 76 | dropout: specify the dropout ratio 77 | fusion: We provide two different fusion method, "concat" or "add" 78 | size: we find that directly learn the attention weights on even 1/8 feature maps is hard. 
79 | Return: 80 | features after "concat" or "add" 81 | """ 82 | 83 | def __init__(self, in_channels, out_channels, key_channels, value_channels, dropout=0.05, sizes=(1,)): 84 | super().__init__() 85 | self.stages = nn.ModuleList( 86 | [SelfAttentionBlock2D(in_channels, key_channels, value_channels, out_channels, size) for size in sizes]) 87 | self.conv_bn_dropout = nn.Sequential( 88 | nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0), 89 | ActivatedBatchNorm(out_channels), 90 | nn.Dropout2d(dropout) 91 | ) 92 | 93 | def forward(self, feats): 94 | priors = [stage(feats) for stage in self.stages] 95 | context = priors[0] 96 | for i in range(1, len(priors)): 97 | context += priors[i] 98 | output = self.conv_bn_dropout(context) 99 | return output 100 | 101 | 102 | class BaseOC(nn.Module): 103 | def __init__(self, in_channels=2048, out_channels=256, dropout=0.05): 104 | super().__init__() 105 | self.block = nn.Sequential( 106 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), 107 | ActivatedBatchNorm(out_channels), 108 | BaseOC_Context(in_channels=out_channels, out_channels=out_channels, 109 | key_channels=out_channels // 2, value_channels=out_channels // 2, dropout=dropout)) 110 | 111 | def forward(self, x): 112 | return self.block(x) 113 | 114 | 115 | class ASPOC(nn.Module): 116 | def __init__(self, in_channels=2048, out_channels=256, output_stride=8): 117 | super().__init__() 118 | if output_stride == 16: 119 | dilations = [6, 12, 18] 120 | elif output_stride == 8: 121 | dilations = [12, 24, 36] 122 | else: 123 | raise NotImplementedError 124 | 125 | self.context = nn.Sequential( 126 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, dilation=1, bias=True), 127 | ActivatedBatchNorm(out_channels), 128 | BaseOC_Context(in_channels=out_channels, out_channels=out_channels, 129 | key_channels=out_channels // 2, value_channels=out_channels, 130 | dropout=0, sizes=([2]))) 131 | self.conv2 = nn.Sequential( 132 | nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0, dilation=1, bias=False), 133 | ActivatedBatchNorm(out_channels)) 134 | self.conv3 = nn.Sequential( 135 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=dilations[0], dilation=dilations[0], 136 | bias=False), 137 | ActivatedBatchNorm(out_channels)) 138 | self.conv4 = nn.Sequential( 139 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=dilations[1], dilation=dilations[1], 140 | bias=False), 141 | ActivatedBatchNorm(out_channels)) 142 | self.conv5 = nn.Sequential( 143 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=dilations[2], dilation=dilations[2], 144 | bias=False), 145 | ActivatedBatchNorm(out_channels)) 146 | self.conv_bn_dropout = nn.Sequential( 147 | nn.Conv2d(out_channels * 5, out_channels, kernel_size=1, padding=0, dilation=1, bias=False), 148 | ActivatedBatchNorm(out_channels), 149 | nn.Dropout2d(0.1) 150 | ) 151 | 152 | def forward(self, x): 153 | _, _, h, w = x.size() 154 | feat1 = self.context(x) 155 | feat2 = self.conv2(x) 156 | feat3 = self.conv3(x) 157 | feat4 = self.conv4(x) 158 | feat5 = self.conv5(x) 159 | 160 | out = torch.cat((feat1, feat2, feat3, feat4, feat5), 1) 161 | output = self.conv_bn_dropout(out) 162 | return output 163 | -------------------------------------------------------------------------------- /src/models/scse.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class SELayer(nn.Module): 6 | def __init__(self, channel, 
reduction=16): 7 | super().__init__() 8 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 9 | self.fc = nn.Sequential( 10 | nn.Linear(channel, int(channel / reduction), bias=False), 11 | nn.ReLU(inplace=True), 12 | nn.Linear(int(channel / reduction), channel, bias=False), 13 | nn.Sigmoid() 14 | ) 15 | 16 | def forward(self, x): 17 | b, c, _, _ = x.size() 18 | y = self.avg_pool(x).view(b, c) 19 | y = self.fc(y).view(b, c, 1, 1) 20 | return x * y.expand_as(x) 21 | 22 | 23 | class SCSEBlock(nn.Module): 24 | def __init__(self, channel, reduction=16): 25 | super().__init__() 26 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 27 | self.channel_excitation = nn.Sequential(nn.Linear(channel, int(channel // reduction)), 28 | nn.ReLU(inplace=True), 29 | nn.Linear(int(channel // reduction), channel)) 30 | self.spatial_se = nn.Conv2d(channel, 1, kernel_size=1, 31 | stride=1, padding=0, bias=False) 32 | 33 | def forward(self, x): 34 | bahs, chs, _, _ = x.size() 35 | 36 | # Returns a new tensor with the same data as the self tensor but of a different size. 37 | chn_se = self.avg_pool(x).view(bahs, chs) 38 | chn_se = torch.sigmoid(self.channel_excitation(chn_se).view(bahs, chs, 1, 1)) 39 | chn_se = torch.mul(x, chn_se) 40 | 41 | spa_se = torch.sigmoid(self.spatial_se(x)) 42 | spa_se = torch.mul(x, spa_se) 43 | return torch.add(chn_se, 1, spa_se) 44 | -------------------------------------------------------------------------------- /src/models/spp.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | from .common import ActivatedBatchNorm, SeparableConv2d 6 | from .oc import BaseOC, ASPOC 7 | 8 | 9 | class SPP(nn.Module): 10 | def __init__(self, in_channels=2048, out_channels=256, pyramids=(1, 2, 3, 6)): 11 | super().__init__() 12 | stages = [] 13 | for p in pyramids: 14 | stages.append(nn.Sequential( 15 | nn.AdaptiveAvgPool2d(p), 16 | nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False), 17 | ActivatedBatchNorm(out_channels) 18 | )) 19 | self.stages = nn.ModuleList(stages) 20 | self.bottleneck = nn.Sequential( 21 | nn.Conv2d(in_channels + out_channels * len(pyramids), out_channels, kernel_size=1), 22 | ActivatedBatchNorm(out_channels) 23 | ) 24 | 25 | def forward(self, x): 26 | x_size = x.size() 27 | out = [x] 28 | for stage in self.stages: 29 | out.append(F.interpolate(stage(x), size=x_size[2:], mode='bilinear', align_corners=False)) 30 | out = self.bottleneck(torch.cat(out, 1)) 31 | return out 32 | 33 | 34 | class ASPP(nn.Module): 35 | def __init__(self, in_channels=2048, out_channels=256, output_stride=8): 36 | super().__init__() 37 | if output_stride == 16: 38 | dilations = [6, 12, 18] 39 | elif output_stride == 8: 40 | dilations = [12, 24, 36] 41 | else: 42 | raise NotImplementedError 43 | # dilations = [6, 12, 18] 44 | 45 | self.aspp0 = nn.Sequential(OrderedDict([('conv', nn.Conv2d(in_channels, out_channels, 1, bias=False)), 46 | ('bn', nn.BatchNorm2d(out_channels)), 47 | ('relu', nn.ReLU(inplace=True))])) 48 | self.aspp1 = SeparableConv2d(in_channels, out_channels, dilation=dilations[0], relu_first=False) 49 | self.aspp2 = SeparableConv2d(in_channels, out_channels, dilation=dilations[1], relu_first=False) 50 | self.aspp3 = SeparableConv2d(in_channels, out_channels, dilation=dilations[2], relu_first=False) 51 | 52 | self.image_pooling = nn.Sequential(OrderedDict([('gap', nn.AdaptiveAvgPool2d((1, 1))), 53 | ('conv', nn.Conv2d(in_channels, out_channels, 
1, bias=False)), 54 | ('bn', nn.BatchNorm2d(out_channels)), 55 | ('relu', nn.ReLU(inplace=True))])) 56 | 57 | self.conv = nn.Conv2d(out_channels*5, out_channels, 1, bias=False) 58 | self.bn = nn.BatchNorm2d(out_channels) 59 | self.relu = nn.ReLU(inplace=True) 60 | self.dropout = nn.Dropout2d(p=0.1) 61 | 62 | def forward(self, x): 63 | pool = self.image_pooling(x) 64 | pool = F.interpolate(pool, size=x.shape[2:], mode='bilinear', align_corners=True) 65 | 66 | x0 = self.aspp0(x) 67 | x1 = self.aspp1(x) 68 | x2 = self.aspp2(x) 69 | x3 = self.aspp3(x) 70 | x = torch.cat((pool, x0, x1, x2, x3), dim=1) 71 | 72 | x = self.conv(x) 73 | x = self.bn(x) 74 | x = self.relu(x) 75 | x = self.dropout(x) 76 | 77 | return x 78 | 79 | 80 | class MobileASPP(nn.Module): 81 | def __init__(self): 82 | super().__init__() 83 | self.aspp0 = nn.Sequential(OrderedDict([('conv', nn.Conv2d(320, 256, 1, bias=False)), 84 | ('bn', nn.BatchNorm2d(256)), 85 | ('relu', nn.ReLU(inplace=True))])) 86 | self.image_pooling = nn.Sequential(OrderedDict([('gap', nn.AdaptiveAvgPool2d((1, 1))), 87 | ('conv', nn.Conv2d(320, 256, 1, bias=False)), 88 | ('bn', nn.BatchNorm2d(256)), 89 | ('relu', nn.ReLU(inplace=True))])) 90 | 91 | self.conv = nn.Conv2d(512, 256, 1, bias=False) 92 | self.bn = nn.BatchNorm2d(256) 93 | self.relu = nn.ReLU(inplace=True) 94 | self.dropout = nn.Dropout2d(p=0.1) 95 | 96 | def forward(self, x): 97 | pool = self.image_pooling(x) 98 | pool = F.interpolate(pool, size=x.shape[2:], mode='bilinear', align_corners=True) 99 | 100 | x = self.aspp0(x) 101 | x = torch.cat((pool, x), dim=1) 102 | 103 | x = self.conv(x) 104 | x = self.bn(x) 105 | x = self.relu(x) 106 | x = self.dropout(x) 107 | 108 | return x 109 | 110 | 111 | class SPPDecoder(nn.Module): 112 | def __init__(self, in_channels, reduced_layer_num=48): 113 | super().__init__() 114 | self.conv = nn.Conv2d(in_channels, reduced_layer_num, 1, bias=False) 115 | self.bn = nn.BatchNorm2d(reduced_layer_num) 116 | self.relu = nn.ReLU(inplace=True) 117 | self.sep1 = SeparableConv2d(256+reduced_layer_num, 256, relu_first=False) 118 | self.sep2 = SeparableConv2d(256, 256, relu_first=False) 119 | 120 | def forward(self, x, low_level_feat): 121 | x = F.interpolate(x, size=low_level_feat.shape[2:], mode='bilinear', align_corners=True) 122 | low_level_feat = self.conv(low_level_feat) 123 | low_level_feat = self.bn(low_level_feat) 124 | low_level_feat = self.relu(low_level_feat) 125 | x = torch.cat((x, low_level_feat), dim=1) 126 | x = self.sep1(x) 127 | x = self.sep2(x) 128 | return x 129 | 130 | 131 | def create_spp(dec_type, in_channels=2048, middle_channels=256, output_stride=8): 132 | if dec_type == 'spp': 133 | return SPP(in_channels, middle_channels), SPPDecoder(middle_channels) 134 | elif dec_type == 'aspp': 135 | return ASPP(in_channels, middle_channels, output_stride), SPPDecoder(middle_channels) 136 | elif dec_type == 'oc_base': 137 | return BaseOC(in_channels, middle_channels), SPPDecoder(middle_channels) 138 | elif dec_type in 'oc_asp': 139 | return ASPOC(in_channels, middle_channels, output_stride), SPPDecoder(middle_channels) 140 | else: 141 | raise NotImplementedError 142 | 143 | 144 | def create_mspp(dec_type): 145 | if dec_type == 'spp': 146 | return SPP(320, 256) 147 | elif dec_type == 'aspp': 148 | return ASPP(320, 256, 8) 149 | elif dec_type == 'oc_base': 150 | return BaseOC(320, 256) 151 | elif dec_type == 'oc_asp': 152 | return ASPOC(320, 256, 8) 153 | elif dec_type == 'maspp': 154 | return MobileASPP() 155 | elif dec_type == 'maspp_dec': 156 | return 
MobileASPP(), SPPDecoder(24, reduced_layer_num=12) 157 | else: 158 | raise NotImplementedError 159 | -------------------------------------------------------------------------------- /src/models/tta.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | 3 | class SegmentatorTTA(object): 4 | @staticmethod 5 | def hflip(x): 6 | return x.flip(3) 7 | 8 | @staticmethod 9 | def vflip(x): 10 | return x.flip(2) 11 | 12 | @staticmethod 13 | def trans(x): 14 | return x.transpose(2, 3) 15 | 16 | def pred_resize(self, x, size, net_type='unet'): 17 | h, w = size 18 | if net_type == 'unet': 19 | pred = self.forward(x) 20 | if x.shape[2:] == size: 21 | return pred 22 | else: 23 | return F.interpolate(pred, size=(h, w), mode='bilinear', align_corners=True) 24 | else: 25 | pred = self.forward(F.pad(x, (0, 1, 0, 1))) 26 | return F.interpolate(pred, size=(h+1, w+1), mode='bilinear', align_corners=True)[..., :h, :w] 27 | 28 | def tta(self, x, scales=None, net_type='unet'): 29 | size = x.shape[2:] 30 | if scales is None: 31 | seg_sum = self.pred_resize(x, size, net_type) 32 | seg_sum += self.hflip(self.pred_resize(self.hflip(x), size, net_type)) 33 | return seg_sum / 2 34 | else: 35 | # scale = 1 36 | seg_sum = self.pred_resize(x, size, net_type) 37 | seg_sum += self.hflip(self.pred_resize(self.hflip(x), size, net_type)) 38 | for scale in scales: 39 | scaled = F.interpolate(x, scale_factor=scale, mode='bilinear', align_corners=True) 40 | seg_sum += self.pred_resize(scaled, size, net_type) 41 | seg_sum += self.hflip(self.pred_resize(self.hflip(scaled), size, net_type)) 42 | return seg_sum / ((len(scales) + 1) * 2) 43 | 44 | -------------------------------------------------------------------------------- /src/models/xception.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .common import SeparableConv2d 5 | 6 | 7 | class XceptionBlock(nn.Module): 8 | def __init__(self, channel_list, stride=1, dilation=1, skip_connection_type='conv', relu_first=True, low_feat=False): 9 | super().__init__() 10 | 11 | assert len(channel_list) == 4 12 | self.skip_connection_type = skip_connection_type 13 | self.relu_first = relu_first 14 | self.low_feat = low_feat 15 | 16 | if self.skip_connection_type == 'conv': 17 | self.conv = nn.Conv2d(channel_list[0], channel_list[-1], 1, stride=stride, bias=False) 18 | self.bn = nn.BatchNorm2d(channel_list[-1]) 19 | 20 | self.sep_conv1 = SeparableConv2d(channel_list[0], channel_list[1], 21 | dilation=dilation, relu_first=relu_first) 22 | self.sep_conv2 = SeparableConv2d(channel_list[1], channel_list[2], 23 | dilation=dilation, relu_first=relu_first) 24 | self.sep_conv3 = SeparableConv2d(channel_list[2], channel_list[3], 25 | dilation=dilation, relu_first=relu_first, stride=stride) 26 | 27 | def forward(self, inputs): 28 | sc1 = self.sep_conv1(inputs) 29 | sc2 = self.sep_conv2(sc1) 30 | residual = self.sep_conv3(sc2) 31 | 32 | if self.skip_connection_type == 'conv': 33 | shortcut = self.conv(inputs) 34 | shortcut = self.bn(shortcut) 35 | outputs = residual + shortcut 36 | elif self.skip_connection_type == 'sum': 37 | outputs = residual + inputs 38 | elif self.skip_connection_type == 'none': 39 | outputs = residual 40 | else: 41 | raise ValueError('Unsupported skip connection type.') 42 | 43 | if self.low_feat: 44 | return outputs, sc2 45 | else: 46 | return outputs 47 | 48 | 49 | class 
Xception65(nn.Module): 50 | def __init__(self, output_stride=8): 51 | super().__init__() 52 | 53 | if output_stride == 16: 54 | entry_block3_stride = 2 55 | middle_block_dilation = 1 56 | exit_block_dilations = (1, 2) 57 | elif output_stride == 8: 58 | entry_block3_stride = 1 59 | middle_block_dilation = 2 60 | exit_block_dilations = (2, 4) 61 | else: 62 | raise NotImplementedError 63 | 64 | # Entry flow 65 | self.conv1 = nn.Conv2d(3, 32, 3, stride=2, padding=1, bias=False) 66 | self.bn1 = nn.BatchNorm2d(32) 67 | self.relu = nn.ReLU() 68 | 69 | self.conv2 = nn.Conv2d(32, 64, 3, stride=1, padding=1, bias=False) 70 | self.bn2 = nn.BatchNorm2d(64) 71 | 72 | self.block1 = XceptionBlock([64, 128, 128, 128], stride=2) 73 | self.block2 = XceptionBlock([128, 256, 256, 256], stride=2, low_feat=True) 74 | self.block3 = XceptionBlock([256, 728, 728, 728], stride=entry_block3_stride) 75 | 76 | # Middle flow (16 units) 77 | self.block4 = XceptionBlock([728, 728, 728, 728], dilation=middle_block_dilation, skip_connection_type='sum') 78 | self.block5 = XceptionBlock([728, 728, 728, 728], dilation=middle_block_dilation, skip_connection_type='sum') 79 | self.block6 = XceptionBlock([728, 728, 728, 728], dilation=middle_block_dilation, skip_connection_type='sum') 80 | self.block7 = XceptionBlock([728, 728, 728, 728], dilation=middle_block_dilation, skip_connection_type='sum') 81 | self.block8 = XceptionBlock([728, 728, 728, 728], dilation=middle_block_dilation, skip_connection_type='sum') 82 | self.block9 = XceptionBlock([728, 728, 728, 728], dilation=middle_block_dilation, skip_connection_type='sum') 83 | self.block10 = XceptionBlock([728, 728, 728, 728], dilation=middle_block_dilation, skip_connection_type='sum') 84 | self.block11 = XceptionBlock([728, 728, 728, 728], dilation=middle_block_dilation, skip_connection_type='sum') 85 | self.block12 = XceptionBlock([728, 728, 728, 728], dilation=middle_block_dilation, skip_connection_type='sum') 86 | self.block13 = XceptionBlock([728, 728, 728, 728], dilation=middle_block_dilation, skip_connection_type='sum') 87 | self.block14 = XceptionBlock([728, 728, 728, 728], dilation=middle_block_dilation, skip_connection_type='sum') 88 | self.block15 = XceptionBlock([728, 728, 728, 728], dilation=middle_block_dilation, skip_connection_type='sum') 89 | self.block16 = XceptionBlock([728, 728, 728, 728], dilation=middle_block_dilation, skip_connection_type='sum') 90 | self.block17 = XceptionBlock([728, 728, 728, 728], dilation=middle_block_dilation, skip_connection_type='sum') 91 | self.block18 = XceptionBlock([728, 728, 728, 728], dilation=middle_block_dilation, skip_connection_type='sum') 92 | self.block19 = XceptionBlock([728, 728, 728, 728], dilation=middle_block_dilation, skip_connection_type='sum') 93 | 94 | # Exit flow 95 | self.block20 = XceptionBlock([728, 728, 1024, 1024], dilation=exit_block_dilations[0]) 96 | self.block21 = XceptionBlock([1024, 1536, 1536, 2048], dilation=exit_block_dilations[1], 97 | skip_connection_type='none', relu_first=False) 98 | 99 | def forward(self, x): 100 | # Entry flow 101 | x = self.conv1(x) 102 | x = self.bn1(x) 103 | x = self.relu(x) 104 | 105 | x = self.conv2(x) 106 | x = self.bn2(x) 107 | x = self.relu(x) 108 | 109 | x = self.block1(x) 110 | x, low_level_feat = self.block2(x) # b, h//4, w//4, 256 111 | x = self.block3(x) # b, h//8, w//8, 728 112 | 113 | # Middle flow 114 | x = self.block4(x) 115 | x = self.block5(x) 116 | x = self.block6(x) 117 | x = self.block7(x) 118 | x = self.block8(x) 119 | x = self.block9(x) 120 | x = 
self.block10(x) 121 | x = self.block11(x) 122 | x = self.block12(x) 123 | x = self.block13(x) 124 | x = self.block14(x) 125 | x = self.block15(x) 126 | x = self.block16(x) 127 | x = self.block17(x) 128 | x = self.block18(x) 129 | x = self.block19(x) 130 | 131 | # Exit flow 132 | x = self.block20(x) 133 | x = self.block21(x) 134 | 135 | return x, low_level_feat 136 | -------------------------------------------------------------------------------- /src/start_train.sh: -------------------------------------------------------------------------------- 1 | if [ -z "$(ps -ef | grep '[p]ython train')" ] 2 | then 3 | nohup python train.py $1 & 4 | fi 5 | -------------------------------------------------------------------------------- /src/stop_train.sh: -------------------------------------------------------------------------------- 1 | kill `ps -ef | grep '[p]ython train' | awk '{print $2}'` 2 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import argparse 3 | import yaml 4 | import numpy as np 5 | import albumentations as albu 6 | from collections import OrderedDict 7 | from pathlib import Path 8 | from tqdm import tqdm 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from torch.utils.data import DataLoader 14 | 15 | from models.net import EncoderDecoderNet, SPPNet 16 | from losses.multi import MultiClassCriterion 17 | from logger.log import debug_logger 18 | from logger.plot import history_ploter 19 | from utils.optimizer import create_optimizer 20 | from utils.metrics import compute_iou_batch 21 | 22 | 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('config_path') 25 | args = parser.parse_args() 26 | config_path = Path(args.config_path) 27 | config = yaml.load(open(config_path)) 28 | net_config = config['Net'] 29 | data_config = config['Data'] 30 | train_config = config['Train'] 31 | loss_config = config['Loss'] 32 | opt_config = config['Optimizer'] 33 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 34 | t_max = opt_config['t_max'] 35 | 36 | max_epoch = train_config['max_epoch'] 37 | batch_size = train_config['batch_size'] 38 | fp16 = train_config['fp16'] 39 | resume = train_config['resume'] 40 | pretrained_path = train_config['pretrained_path'] 41 | 42 | # Network 43 | if 'unet' in net_config['dec_type']: 44 | net_type = 'unet' 45 | model = EncoderDecoderNet(**net_config) 46 | else: 47 | net_type = 'deeplab' 48 | model = SPPNet(**net_config) 49 | 50 | dataset = data_config['dataset'] 51 | if dataset == 'pascal': 52 | from dataset.pascal_voc import PascalVocDataset as Dataset 53 | net_config['output_channels'] = 21 54 | classes = np.arange(1, 21) 55 | elif dataset == 'cityscapes': 56 | from dataset.cityscapes import CityscapesDataset as Dataset 57 | net_config['output_channels'] = 19 58 | classes = np.arange(1, 19) 59 | else: 60 | raise NotImplementedError 61 | del data_config['dataset'] 62 | 63 | modelname = config_path.stem 64 | output_dir = Path('../model') / modelname 65 | output_dir.mkdir(exist_ok=True) 66 | log_dir = Path('../logs') / modelname 67 | log_dir.mkdir(exist_ok=True) 68 | 69 | logger = debug_logger(log_dir) 70 | logger.debug(config) 71 | logger.info(f'Device: {device}') 72 | logger.info(f'Max Epoch: {max_epoch}') 73 | 74 | # Loss 75 | loss_fn = MultiClassCriterion(**loss_config).to(device) 76 | params = model.parameters() 77 | optimizer, 
scheduler = create_optimizer(params, **opt_config) 78 | 79 | # history 80 | if resume: 81 | with open(log_dir.joinpath('history.pkl'), 'rb') as f: 82 | history_dict = pickle.load(f) 83 | best_metrics = history_dict['best_metrics'] 84 | loss_history = history_dict['loss'] 85 | iou_history = history_dict['iou'] 86 | start_epoch = len(iou_history) 87 | for _ in range(start_epoch): 88 | scheduler.step() 89 | else: 90 | start_epoch = 0 91 | best_metrics = 0 92 | loss_history = [] 93 | iou_history = [] 94 | 95 | # Dataset 96 | affine_augmenter = albu.Compose([albu.HorizontalFlip(p=.5), 97 | # Rotate(5, p=.5) 98 | ]) 99 | # image_augmenter = albu.Compose([albu.GaussNoise(p=.5), 100 | # albu.RandomBrightnessContrast(p=.5)]) 101 | image_augmenter = None 102 | train_dataset = Dataset(affine_augmenter=affine_augmenter, image_augmenter=image_augmenter, 103 | net_type=net_type, **data_config) 104 | valid_dataset = Dataset(split='valid', net_type=net_type, **data_config) 105 | train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, 106 | pin_memory=True, drop_last=True) 107 | valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False, num_workers=4, pin_memory=True) 108 | 109 | # To device 110 | model = model.to(device) 111 | 112 | # Pretrained model 113 | if pretrained_path: 114 | logger.info(f'Resume from {pretrained_path}') 115 | param = torch.load(pretrained_path) 116 | model.load_state_dict(param) 117 | del param 118 | 119 | # fp16 120 | if fp16: 121 | from apex import fp16_utils 122 | model = fp16_utils.BN_convert_float(model.half()) 123 | optimizer = fp16_utils.FP16_Optimizer(optimizer, verbose=False, dynamic_loss_scale=True) 124 | logger.info('Apply fp16') 125 | 126 | # Restore model 127 | if resume: 128 | model_path = output_dir.joinpath(f'model_tmp.pth') 129 | logger.info(f'Resume from {model_path}') 130 | param = torch.load(model_path) 131 | model.load_state_dict(param) 132 | del param 133 | opt_path = output_dir.joinpath(f'opt_tmp.pth') 134 | param = torch.load(opt_path) 135 | optimizer.load_state_dict(param) 136 | del param 137 | 138 | # Train 139 | for i_epoch in range(start_epoch, max_epoch): 140 | logger.info(f'Epoch: {i_epoch}') 141 | logger.info(f'Learning rate: {optimizer.param_groups[0]["lr"]}') 142 | 143 | train_losses = [] 144 | train_ious = [] 145 | model.train() 146 | with tqdm(train_loader) as _tqdm: 147 | for batched in _tqdm: 148 | images, labels, _ = batched 149 | if fp16: 150 | images = images.half() 151 | images, labels = images.to(device), labels.to(device) 152 | optimizer.zero_grad() 153 | preds = model(images) 154 | if net_type == 'deeplab': 155 | preds = F.interpolate(preds, size=labels.shape[1:], mode='bilinear', align_corners=True) 156 | if fp16: 157 | loss = loss_fn(preds.float(), labels) 158 | else: 159 | loss = loss_fn(preds, labels) 160 | 161 | preds_np = preds.detach().cpu().numpy() 162 | labels_np = labels.detach().cpu().numpy() 163 | iou = compute_iou_batch(np.argmax(preds_np, axis=1), labels_np, classes) 164 | 165 | _tqdm.set_postfix(OrderedDict(seg_loss=f'{loss.item():.5f}', iou=f'{iou:.3f}')) 166 | train_losses.append(loss.item()) 167 | train_ious.append(iou) 168 | 169 | if fp16: 170 | optimizer.backward(loss) 171 | else: 172 | loss.backward() 173 | optimizer.step() 174 | 175 | scheduler.step() 176 | 177 | train_loss = np.mean(train_losses) 178 | train_iou = np.nanmean(train_ious) 179 | logger.info(f'train loss: {train_loss}') 180 | logger.info(f'train iou: {train_iou}') 181 | 182 | 
torch.save(model.state_dict(), output_dir.joinpath('model_tmp.pth')) 183 | torch.save(optimizer.state_dict(), output_dir.joinpath('opt_tmp.pth')) 184 | 185 | if (i_epoch + 1) % 1 == 0: 186 | valid_losses = [] 187 | valid_ious = [] 188 | model.eval() 189 | with torch.no_grad(): 190 | with tqdm(valid_loader) as _tqdm: 191 | for batched in _tqdm: 192 | images, labels, _ = batched 193 | if fp16: 194 | images = images.half() 195 | images, labels = images.to(device), labels.to(device) 196 | preds = model.tta(images, net_type=net_type) 197 | if fp16: 198 | loss = loss_fn(preds.float(), labels) 199 | else: 200 | loss = loss_fn(preds, labels) 201 | 202 | preds_np = preds.detach().cpu().numpy() 203 | labels_np = labels.detach().cpu().numpy() 204 | iou = compute_iou_batch(np.argmax(preds_np, axis=1), labels_np, classes) 205 | 206 | _tqdm.set_postfix(OrderedDict(seg_loss=f'{loss.item():.5f}', iou=f'{iou:.3f}')) 207 | valid_losses.append(loss.item()) 208 | valid_ious.append(iou) 209 | 210 | valid_loss = np.mean(valid_losses) 211 | valid_iou = np.mean(valid_ious) 212 | logger.info(f'valid seg loss: {valid_loss}') 213 | logger.info(f'valid iou: {valid_iou}') 214 | 215 | if best_metrics < valid_iou: 216 | best_metrics = valid_iou 217 | logger.info('Best Model!') 218 | torch.save(model.state_dict(), output_dir.joinpath('model.pth')) 219 | torch.save(optimizer.state_dict(), output_dir.joinpath('opt.pth')) 220 | else: 221 | valid_loss = None 222 | valid_iou = None 223 | 224 | loss_history.append([train_loss, valid_loss]) 225 | iou_history.append([train_iou, valid_iou]) 226 | history_ploter(loss_history, log_dir.joinpath('loss.png')) 227 | history_ploter(iou_history, log_dir.joinpath('iou.png')) 228 | 229 | history_dict = {'loss': loss_history, 230 | 'iou': iou_history, 231 | 'best_metrics': best_metrics} 232 | with open(log_dir.joinpath('history.pkl'), 'wb') as f: 233 | pickle.dump(history_dict, f) 234 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nyoki-mtl/pytorch-segmentation/ddf7ce97114afa0f18902ba4a8f6dd6db581bcec/src/utils/__init__.py -------------------------------------------------------------------------------- /src/utils/custum_aug.py: -------------------------------------------------------------------------------- 1 | import random 2 | import cv2 3 | import numpy as np 4 | from albumentations.core.transforms_interface import to_tuple, ImageOnlyTransform, DualTransform 5 | 6 | 7 | def apply_motion_blur(image, count): 8 | """ 9 | https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library 10 | """ 11 | image_t = image.copy() 12 | imshape = image_t.shape 13 | size = 15 14 | kernel_motion_blur = np.zeros((size, size)) 15 | kernel_motion_blur[int((size - 1) / 2), :] = np.ones(size) 16 | kernel_motion_blur = kernel_motion_blur / size 17 | i = imshape[1] * 3 // 4 - 10 * count 18 | while i <= imshape[1]: 19 | image_t[:, i:, :] = cv2.filter2D(image_t[:, i:, :], -1, kernel_motion_blur) 20 | image_t[:, :imshape[1] - i, :] = cv2.filter2D(image_t[:, :imshape[1] - i, :], -1, kernel_motion_blur) 21 | i += imshape[1] // 25 - count 22 | count += 1 23 | color_image = image_t 24 | return color_image 25 | 26 | 27 | def rotate(img, angle, interpolation, border_mode, border_value=None): 28 | height, width = img.shape[:2] 29 | matrix = cv2.getRotationMatrix2D((width / 2, height / 2), angle, 1.0) 30 | img = cv2.warpAffine(img, matrix, 
(width, height), 31 | flags=interpolation, borderMode=border_mode, borderValue=border_value) 32 | return img 33 | 34 | 35 | class AddSpeed(ImageOnlyTransform): 36 | def __init__(self, speed_coef=-1, p=.5): 37 | super().__init__(p) 38 | assert speed_coef == -1 or 0 <= speed_coef <= 1 39 | self.speed_coef = speed_coef 40 | 41 | def apply(self, img, count=7, **params): 42 | return apply_motion_blur(img, count) 43 | 44 | def get_params(self): 45 | if self.speed_coef == -1: 46 | return {'count': int(15 * random.uniform(0, 1))} 47 | else: 48 | return {'count': int(15 * self.speed_coef)} 49 | 50 | 51 | class Rotate(DualTransform): 52 | def __init__(self, limit=90, interpolation=cv2.INTER_LINEAR, 53 | border_mode=cv2.BORDER_REFLECT_101, border_value=255, always_apply=False, p=.5): 54 | super().__init__(always_apply, p) 55 | self.limit = to_tuple(limit) 56 | self.interpolation = interpolation 57 | self.border_mode = border_mode 58 | self.border_value = border_value 59 | 60 | def apply(self, img, angle=0, **params): 61 | return rotate(img, angle, interpolation=self.interpolation, border_mode=self.border_mode) 62 | 63 | def apply_to_mask(self, img, angle=0, **params): 64 | return rotate(img, angle, interpolation=cv2.INTER_NEAREST, 65 | border_mode=cv2.BORDER_CONSTANT, border_value=self.border_value) 66 | 67 | def get_params(self): 68 | return {'angle': random.uniform(self.limit[0], self.limit[1])} 69 | 70 | 71 | class PadIfNeededRightBottom(DualTransform): 72 | def __init__(self, min_height=769, min_width=769, border_mode=cv2.BORDER_CONSTANT, 73 | value=0, ignore_index=255, always_apply=False, p=1.0): 74 | super().__init__(always_apply, p) 75 | self.min_height = min_height 76 | self.min_width = min_width 77 | self.border_mode = border_mode 78 | self.value = value 79 | self.ignore_index = ignore_index 80 | 81 | def apply(self, img, **params): 82 | img_height, img_width = img.shape[:2] 83 | pad_height = max(0, self.min_height-img_height) 84 | pad_width = max(0, self.min_width-img_width) 85 | return np.pad(img, ((0, pad_height), (0, pad_width), (0, 0)), 'constant', constant_values=self.value) 86 | 87 | def apply_to_mask(self, img, **params): 88 | img_height, img_width = img.shape[:2] 89 | pad_height = max(0, self.min_height-img_height) 90 | pad_width = max(0, self.min_width-img_width) 91 | return np.pad(img, ((0, pad_height), (0, pad_width)), 'constant', constant_values=self.ignore_index) 92 | -------------------------------------------------------------------------------- /src/utils/functional.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False): 5 | if b is not None: 6 | a, b = np.broadcast_arrays(a, b) 7 | if np.any(b == 0): 8 | a = a + 0. 
# promote to at least float 9 | a[b == 0] = -np.inf 10 | 11 | a_max = np.amax(a, axis=axis, keepdims=True) 12 | 13 | if a_max.ndim > 0: 14 | a_max[~np.isfinite(a_max)] = 0 15 | elif not np.isfinite(a_max): 16 | a_max = 0 17 | 18 | if b is not None: 19 | b = np.asarray(b) 20 | tmp = b * np.exp(a - a_max) 21 | else: 22 | tmp = np.exp(a - a_max) 23 | 24 | # suppress warnings about log of zero 25 | with np.errstate(divide='ignore'): 26 | s = np.sum(tmp, axis=axis, keepdims=keepdims) 27 | if return_sign: 28 | sgn = np.sign(s) 29 | s *= sgn # /= makes more sense but we need zero -> zero 30 | out = np.log(s) 31 | 32 | if not keepdims: 33 | a_max = np.squeeze(a_max, axis=axis) 34 | out += a_max 35 | 36 | if return_sign: 37 | return out, sgn 38 | else: 39 | return out 40 | 41 | 42 | def softmax(x, axis=None): 43 | return np.exp(x - logsumexp(x, axis=axis, keepdims=True)) -------------------------------------------------------------------------------- /src/utils/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import warnings 3 | warnings.filterwarnings('ignore', category=RuntimeWarning) 4 | 5 | 6 | def compute_ious(pred, label, classes, ignore_index=255, only_present=True): 7 | pred[label == ignore_index] = 0 8 | ious = [] 9 | for c in classes: 10 | label_c = label == c 11 | if only_present and np.sum(label_c) == 0: 12 | ious.append(np.nan) 13 | continue 14 | pred_c = pred == c 15 | intersection = np.logical_and(pred_c, label_c).sum() 16 | union = np.logical_or(pred_c, label_c).sum() 17 | if union != 0: 18 | ious.append(intersection / union) 19 | return ious if ious else [1] 20 | 21 | 22 | def compute_iou_batch(preds, labels, classes=None): 23 | iou = np.nanmean([np.nanmean(compute_ious(pred, label, classes)) for pred, label in zip(preds, labels)]) 24 | return iou 25 | 26 | 27 | def iou_analyzer(preds, labels, tods): 28 | mIoU = np.nanmean([np.nanmean(compute_ious(pred, label, [1, 2, 3, 4])) for pred, label in zip(preds, labels)]) 29 | print(f'Valid mIoU: {mIoU:.3f}\n') 30 | 31 | class_names = ['car', 'person', 'signal', 'road'] 32 | tod_names = ['morning', 'day', 'night'] 33 | iou_dict = {tod_name: dict(zip(class_names, [[] for _ in range(len(class_names))])) for tod_name in tod_names} 34 | for pred, label, tod in zip(preds, labels, tods): 35 | iou_per_class = compute_ious(pred, label, [1, 2, 3, 4]) 36 | for iou, class_name in zip(iou_per_class, class_names): 37 | iou_dict[tod][class_name].append(iou) 38 | 39 | for tod_name in tod_names: 40 | print(f'\n---{tod_name}---') 41 | for k, v in iou_dict[tod_name].items(): 42 | print(f'{k}: {np.nanmean(v):.3f}') 43 | 44 | print('\n---ALL---') 45 | for class_name in class_names: 46 | ious = [] 47 | for tod_name in tod_names: 48 | ious += iou_dict[tod_name][class_name] 49 | print(f'{class_name}: {np.nanmean(ious):.3f}') 50 | -------------------------------------------------------------------------------- /src/utils/optimizer.py: -------------------------------------------------------------------------------- 1 | import torch.optim as optim 2 | from .scheduler import CosineWithRestarts 3 | 4 | 5 | def create_optimizer(params, mode='adam', base_lr=1e-3, t_max=10): 6 | if mode == 'adam': 7 | optimizer = optim.Adam(params, base_lr) 8 | elif mode == 'sgd': 9 | optimizer = optim.SGD(params, base_lr, momentum=0.9, weight_decay=4e-5) 10 | else: 11 | raise NotImplementedError(mode) 12 | 13 | scheduler = CosineWithRestarts(optimizer, t_max) 14 | 15 | return optimizer, scheduler 16 | 
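A minimal usage sketch of create_optimizer as defined above (the toy model, epoch count and hyper-parameter values are placeholders, not taken from the repository's configs):

import torch.nn as nn
from utils.optimizer import create_optimizer

model = nn.Conv2d(3, 19, 1)  # stand-in for EncoderDecoderNet / SPPNet
optimizer, scheduler = create_optimizer(model.parameters(), mode='adam', base_lr=1e-3, t_max=10)
for epoch in range(20):
    # ... forward/backward passes and optimizer.step() for one epoch ...
    scheduler.step()  # cosine annealing; with the default factor=1 the schedule restarts every t_max epochs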
-------------------------------------------------------------------------------- /src/utils/preprocess.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | 4 | 5 | def minmax_normalize(img, norm_range=(0, 1), orig_range=(0, 255)): 6 | # range(0, 1) 7 | norm_img = (img - orig_range[0]) / (orig_range[1] - orig_range[0]) 8 | # range(min_value, max_value) 9 | norm_img = norm_img * (norm_range[1] - norm_range[0]) + norm_range[0] 10 | return norm_img 11 | 12 | 13 | def meanstd_normalize(img, mean, std): 14 | mean = np.asarray(mean) 15 | std = np.asarray(std) 16 | norm_img = (img - mean) / std 17 | return norm_img 18 | 19 | 20 | def padding(img, pad, constant_values=0): 21 | pad_img = np.pad(img, pad, 'constant', constant_values=constant_values) 22 | return pad_img 23 | 24 | 25 | def clahe(img, clip=2, grid=8): 26 | img_yuv = cv2.cvtColor(img, cv2.COLOR_BGR2LAB) 27 | _clahe = cv2.createCLAHE(clipLimit=clip, tileGridSize=(grid, grid)) 28 | img_yuv[:, :, 0] = _clahe.apply(img_yuv[:, :, 0]) 29 | img_equ = cv2.cvtColor(img_yuv, cv2.COLOR_LAB2BGR) 30 | return img_equ 31 | -------------------------------------------------------------------------------- /src/utils/scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | """ 4 | https://github.com/allenai/allennlp/pull/1647/files 5 | """ 6 | 7 | 8 | class CosineWithRestarts(torch.optim.lr_scheduler._LRScheduler): # pylint: disable=protected-access 9 | """ 10 | Cosine annealing with restarts. 11 | This is decribed in the paper https://arxiv.org/abs/1608.03983. 12 | Parameters 13 | ---------- 14 | optimizer : ``torch.optim.Optimizer`` 15 | t_max : ``int`` 16 | The maximum number of iterations within the first cycle. 17 | eta_min : ``float``, optional (default=0) 18 | The minimum learning rate. 19 | last_epoch : ``int``, optional (default=-1) 20 | The index of the last epoch. This is used when restarting. 21 | factor : ``float``, optional (default=1) 22 | The factor by which the cycle length (``T_max``) increases after each restart. 23 | """ 24 | def __init__(self, 25 | optimizer: torch.optim.Optimizer, 26 | t_max: int, 27 | eta_min: float = 0., 28 | last_epoch: int = -1, 29 | factor: float = 1.) -> None: 30 | assert t_max > 0 31 | assert eta_min >= 0 32 | self.t_max = t_max 33 | self.eta_min = eta_min 34 | self.factor = factor 35 | self._last_restart: int = 0 36 | self._cycle_counter: int = 0 37 | self._cycle_factor: float = 1. 38 | self._updated_cycle_len: int = t_max 39 | self._initialized: bool = False 40 | super(CosineWithRestarts, self).__init__(optimizer, last_epoch) 41 | def get_lr(self): 42 | """Get updated learning rate.""" 43 | # HACK: We need to check if this is the first time ``self.get_lr()`` was called, 44 | # since ``torch.optim.lr_scheduler._LRScheduler`` will call ``self.get_lr()`` 45 | # when first initialized, but the learning rate should remain unchanged 46 | # for the first epoch. 47 | if not self._initialized: 48 | self._initialized = True 49 | return self.base_lrs 50 | step = self.last_epoch + 1 51 | self._cycle_counter = step - self._last_restart 52 | lrs = [ 53 | self.eta_min + ((lr - self.eta_min) / 2) * ( 54 | np.cos( 55 | np.pi * 56 | (self._cycle_counter % self._updated_cycle_len) / 57 | self._updated_cycle_len 58 | ) + 1 59 | ) 60 | for lr in self.base_lrs 61 | ] 62 | if self._cycle_counter % self._updated_cycle_len == 0: 63 | # Adjust the cycle length. 
64 | self._cycle_factor *= self.factor 65 | self._cycle_counter = 0 66 | self._updated_cycle_len = int(self._cycle_factor * self.t_max) 67 | self._last_restart = step 68 | return lrs 69 | -------------------------------------------------------------------------------- /src/utils/visualize.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | 4 | 5 | n_classes = 5 6 | valid_colors = [(0, 0, 0), 7 | (0, 0, 255), 8 | (255, 0, 0), 9 | (255, 255, 0), 10 | (69, 47, 142), 11 | ] 12 | class_map = dict(zip(valid_colors, range(n_classes))) 13 | own_mask = np.array(Image.open('../preprocess/own_mask010.png')).astype(bool) 14 | 15 | 16 | def encode_mask(color_mask): 17 | valid_mask = np.zeros((color_mask.shape[0], color_mask.shape[1]), dtype=np.uint8) 18 | colors = valid_colors[1:] 19 | 20 | for c in colors: 21 | tmp_index = color_mask == c 22 | index = np.einsum('ij,ij,ij->ij', tmp_index[:, :, 0], tmp_index[:, :, 1], tmp_index[:, :, 2]) 23 | valid_mask[index] = class_map[c] 24 | valid_mask[own_mask] = 0 25 | return valid_mask 26 | 27 | 28 | def label_colormap(n=256): 29 | def bitget(byteval, idx): 30 | return (byteval & (1 << idx)) != 0 31 | 32 | cmap = np.zeros((n, 3)) 33 | for i in range(0, n): 34 | id = i 35 | r, g, b = 0, 0, 0 36 | for j in range(0, 8): 37 | r = np.bitwise_or(r, (bitget(id, 0) << 7 - j)) 38 | g = np.bitwise_or(g, (bitget(id, 1) << 7 - j)) 39 | b = np.bitwise_or(b, (bitget(id, 2) << 7 - j)) 40 | id = (id >> 3) 41 | cmap[i, 0] = r 42 | cmap[i, 1] = g 43 | cmap[i, 2] = b 44 | cmap = cmap.astype(np.float32) / 255 45 | return cmap 46 | 47 | 48 | def label2rgb(lbl, img=None, n_labels=n_classes, ignore_index=255, alpha=0.3, to_gray=False): 49 | cmap = label_colormap(n_labels) 50 | cmap = (cmap * 255).astype(np.uint8) 51 | 52 | lbl_viz = cmap[lbl] 53 | lbl_viz[lbl == ignore_index] = (0, 0, 0) # unlabeled 54 | 55 | if img is not None: 56 | if to_gray: 57 | img = Image.fromarray(img).convert('LA') 58 | img = np.asarray(img.convert('RGB')) 59 | lbl_viz = alpha * lbl_viz + (1 - alpha) * img 60 | lbl_viz = lbl_viz.astype(np.uint8) 61 | 62 | return lbl_viz 63 | -------------------------------------------------------------------------------- /tf_model/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nyoki-mtl/pytorch-segmentation/ddf7ce97114afa0f18902ba4a8f6dd6db581bcec/tf_model/.gitkeep --------------------------------------------------------------------------------
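A sketch of chaining the helpers in `src/utils/preprocess.py` above in the usual order (CLAHE, then min-max scaling, then mean/std normalization). The ImageNet mean/std values below are illustrative assumptions, not values read from this repository's config files.

```python
# Preprocessing sketch; mean/std are illustrative ImageNet statistics.
import numpy as np
from src.utils.preprocess import clahe, minmax_normalize, meanstd_normalize

img = np.random.randint(0, 256, (512, 512, 3), dtype=np.uint8)   # BGR image, as loaded by cv2

img_eq = clahe(img, clip=2, grid=8)          # contrast-limited equalization on the L channel
img_01 = minmax_normalize(img_eq)            # [0, 255] -> [0, 1]
img_norm = meanstd_normalize(img_01,
                             mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
print(img_norm.shape, img_norm.dtype)        # (512, 512, 3) float64
```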
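And a sketch of overlaying a label map on an image with `label2rgb` from `src/utils/visualize.py` above. Note that importing that module opens `../preprocess/own_mask010.png` at module level, so this only runs from a working directory where that relative path resolves; that constraint and the import path are assumptions here.

```python
# Sketch of blending a label map over an image with label2rgb above.
# Assumes visualize.py can be imported, i.e. its module-level
# '../preprocess/own_mask010.png' path resolves from the current directory.
import numpy as np
from src.utils.visualize import label2rgb

img = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)
lbl = np.random.randint(0, 5, (256, 256), dtype=np.uint8)   # class ids 0..4

overlay = label2rgb(lbl, img=img, n_labels=5, alpha=0.3)
print(overlay.shape, overlay.dtype)  # (256, 256, 3) uint8
```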