├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── docs ├── Thumbnail.png └── exam_result │ ├── ground_truth1.png │ ├── ground_truth2.png │ ├── input1.png │ ├── input2.png │ ├── result1.png │ └── result2.png ├── requirements.txt ├── segmentation ├── __init__.py ├── data_loader │ ├── __init__.py │ ├── segmentation_dataset.py │ └── transform.py ├── encoders │ ├── __init__.py │ ├── mobilenet.py │ ├── resnet.py │ ├── squeeze_extractor.py │ └── vgg.py ├── example │ └── dataset │ │ └── cityspaces │ │ ├── images │ │ ├── test │ │ │ ├── lindau_000000_000019.png │ │ │ ├── lindau_000001_000019.png │ │ │ ├── lindau_000002_000019.png │ │ │ ├── lindau_000003_000019.png │ │ │ ├── lindau_000004_000019.png │ │ │ └── lindau_000005_000019.png │ │ └── train │ │ │ ├── jena_000000_000019.png │ │ │ ├── jena_000001_000019.png │ │ │ ├── jena_000002_000019.png │ │ │ ├── jena_000003_000019.png │ │ │ ├── jena_000004_000019.png │ │ │ ├── jena_000005_000019.png │ │ │ ├── jena_000006_000019.png │ │ │ ├── jena_000007_000019.png │ │ │ ├── jena_000008_000019.png │ │ │ ├── jena_000009_000019.png │ │ │ └── jena_000010_000019.png │ │ └── labeled │ │ ├── test │ │ ├── lindau_000000_000019.png │ │ ├── lindau_000001_000019.png │ │ ├── lindau_000002_000019.png │ │ ├── lindau_000003_000019.png │ │ ├── lindau_000004_000019.png │ │ └── lindau_000005_000019.png │ │ └── train │ │ ├── jena_000000_000019.png │ │ ├── jena_000001_000019.png │ │ ├── jena_000002_000019.png │ │ ├── jena_000003_000019.png │ │ ├── jena_000004_000019.png │ │ ├── jena_000005_000019.png │ │ ├── jena_000006_000019.png │ │ ├── jena_000007_000019.png │ │ ├── jena_000008_000019.png │ │ ├── jena_000009_000019.png │ │ └── jena_000010_000019.png ├── models │ ├── __init__.py │ ├── all_models.py │ ├── fcn16.py │ ├── fcn32.py │ ├── fcn8.py │ ├── pspnet.py │ └── unet.py ├── predict.py ├── test │ ├── __init__.py │ ├── dataset │ │ └── cityspaces │ │ │ ├── images │ │ │ ├── test │ │ │ │ ├── lindau_000000_000019.png │ │ │ │ ├── lindau_000001_000019.png │ │ │ │ ├── lindau_000002_000019.png │ │ │ │ ├── lindau_000003_000019.png │ │ │ │ ├── lindau_000004_000019.png │ │ │ │ └── lindau_000005_000019.png │ │ │ └── train │ │ │ │ ├── jena_000000_000019.png │ │ │ │ ├── jena_000001_000019.png │ │ │ │ ├── jena_000002_000019.png │ │ │ │ ├── jena_000003_000019.png │ │ │ │ ├── jena_000004_000019.png │ │ │ │ ├── jena_000005_000019.png │ │ │ │ ├── jena_000006_000019.png │ │ │ │ ├── jena_000007_000019.png │ │ │ │ ├── jena_000008_000019.png │ │ │ │ ├── jena_000009_000019.png │ │ │ │ └── jena_000010_000019.png │ │ │ ├── input.png │ │ │ └── labeled │ │ │ ├── test │ │ │ ├── lindau_000000_000019.png │ │ │ ├── lindau_000001_000019.png │ │ │ ├── lindau_000002_000019.png │ │ │ ├── lindau_000003_000019.png │ │ │ ├── lindau_000004_000019.png │ │ │ └── lindau_000005_000019.png │ │ │ └── train │ │ │ ├── jena_000000_000019.png │ │ │ ├── jena_000001_000019.png │ │ │ ├── jena_000002_000019.png │ │ │ ├── jena_000003_000019.png │ │ │ ├── jena_000004_000019.png │ │ │ ├── jena_000005_000019.png │ │ │ ├── jena_000006_000019.png │ │ │ ├── jena_000007_000019.png │ │ │ ├── jena_000008_000019.png │ │ │ ├── jena_000009_000019.png │ │ │ └── jena_000010_000019.png │ └── scrach.py └── trainer.py ├── setup.cfg ├── setup.py └── util ├── __init__.py ├── checkpoint.py ├── imshow.py ├── logger.py └── validation.py /.gitignore: -------------------------------------------------------------------------------- 1 | /segmentation/scrach.py 2 | -------------------------------------------------------------------------------- 
/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2020 Ian Taehoon Yoo 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | include LICENSE 3 | include requirements.txt -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Semantic-Segmentation-Pytorch 2 | 3 | [![PyPI Version](https://img.shields.io/pypi/v/seg-torch.svg)](https://pypi.org/project/seg-torch) 4 | 5 | PyTorch implementation of FCN, UNet, PSPNet, and various encoder models for semantic segmentation. 6 |

7 | ![](docs/Thumbnail.png) 8 | 9 |

10 | 11 | These are reference implementations of the following models. 12 | - FCN (Fully Convolutional Networks for Semantic Segmentation) [[Paper]](https://arxiv.org/abs/1411.4038) 13 | - UNet (Convolutional Networks for Biomedical Image Segmentation) [[Paper]](https://arxiv.org/abs/1505.04597) 14 | - PSPNet (Pyramid Scene Parsing Network) [[Paper]](https://arxiv.org/abs/1612.01105) 15 | 16 | ## Models 17 | 18 | This project supports the following models: 19 | 20 | | model name | backbone model | decoder model | 21 | |:-------------------:|:-------------------:|:-------------------:| 22 | | fcn8_vgg11 | VGG 11 | FCN8 | 23 | | fcn8_vgg13 | VGG 13 | FCN8 | 24 | | fcn8_vgg16 | VGG 16 | FCN8 | 25 | | fcn8_vgg19 | VGG 19 | FCN8 | 26 | | fcn16_vgg11 | VGG 11 | FCN16 | 27 | | fcn16_vgg13 | VGG 13 | FCN16 | 28 | | fcn16_vgg16 | VGG 16 | FCN16 | 29 | | fcn16_vgg19 | VGG 19 | FCN16 | 30 | | fcn32_vgg11 | VGG 11 | FCN32 | 31 | | fcn32_vgg13 | VGG 13 | FCN32 | 32 | | fcn32_vgg16 | VGG 16 | FCN32 | 33 | | fcn32_vgg19 | VGG 19 | FCN32 | 34 | | fcn8_resnet18 | Resnet-18 | FCN8 | 35 | | fcn8_resnet34 | Resnet-34 | FCN8 | 36 | | fcn8_resnet50 | Resnet-50 | FCN8 | 37 | | fcn8_resnet101 | Resnet-101 | FCN8 | 38 | | fcn8_resnet152 | Resnet-152 | FCN8 | 39 | | fcn16_resnet18 | Resnet-18 | FCN16 | 40 | | fcn16_resnet34 | Resnet-34 | FCN16 | 41 | | fcn16_resnet50 | Resnet-50 | FCN16 | 42 | | fcn16_resnet101 | Resnet-101 | FCN16 | 43 | | fcn16_resnet152 | Resnet-152 | FCN16 | 44 | | fcn32_resnet18 | Resnet-18 | FCN32 | 45 | | fcn32_resnet34 | Resnet-34 | FCN32 | 46 | | fcn32_resnet50 | Resnet-50 | FCN32 | 47 | | fcn32_resnet101 | Resnet-101 | FCN32 | 48 | | fcn32_resnet152 | Resnet-152 | FCN32 | 49 | | fcn8_mobilenet_v2 | MobileNet-v2 | FCN8 | 50 | | fcn16_mobilenet_v2 | MobileNet-v2 | FCN16 | 51 | | fcn32_mobilenet_v2 | MobileNet-v2 | FCN32 | 52 | | unet | Unet | Unet | 53 | | unet_vgg11 | VGG11 | Unet | 54 | | unet_vgg13 | VGG13 | Unet | 55 | | unet_vgg16 | VGG16 | Unet | 56 | | unet_vgg19 | VGG19 | Unet | 57 | | unet_resnet18 | Resnet-18 | Unet | 58 | | unet_resnet34 | Resnet-34 | Unet | 59 | | unet_resnet50 | Resnet-50 | Unet | 60 | | unet_resnet101 | Resnet-101 | Unet | 61 | | unet_resnet152 | Resnet-152 | Unet | 62 | | unet_mobilenet_v2 | MobileNet-v2 | Unet | 63 | | pspnet_vgg11 | VGG11 | PSPNet | 64 | | pspnet_vgg13 | VGG13 | PSPNet | 65 | | pspnet_vgg16 | VGG16 | PSPNet | 66 | | pspnet_vgg19 | VGG19 | PSPNet | 67 | | pspnet_resnet18 | Resnet-18 | PSPNet | 68 | | pspnet_resnet34 | Resnet-34 | PSPNet | 69 | | pspnet_resnet50 | Resnet-50 | PSPNet | 70 | | pspnet_resnet101 | Resnet-101 | PSPNet | 71 | | pspnet_resnet152 | Resnet-152 | PSPNet | 72 | | pspnet_mobilenet_v2 | MobileNet-v2 | PSPNet | 73 | 74 | Example results of the pspnet_mobilenet_v2 model: 75 | 76 | Input Image | Ground Truth Image | Result Image | 77 | :-------------------------:|:-------------------------:|:-------------------------:| 78 | ![](docs/exam_result/input1.png) | ![](docs/exam_result/ground_truth1.png) | ![](docs/exam_result/result1.png) 79 | ![](docs/exam_result/input2.png) | ![](docs/exam_result/ground_truth2.png) | ![](docs/exam_result/result2.png) 80 | 81 | ## Getting Started 82 | 83 | ### Requirements 84 | 85 | - [pytorch](https://github.com/pytorch/pytorch) >= 1.5.0 86 | - [torchvision](https://github.com/pytorch/vision) >= 0.6.0 87 | - [TensorboardX](https://github.com/lanpa/tensorboardX) >= 0.2.0 88 | - opencv-python 89 | - [tqdm](https://github.com/tqdm/tqdm) 90 | 91 | ### Installation 92 | 93 | ```shell 94 | pip install seg-torch
95 | ``` 96 | 97 | or 98 | 99 | ```shell 100 | git clone https://github.com/IanTaehoonYoo/semantic-segmentation-pytorch/ 101 | cd semantic-segmentation-pytorch 102 | python setup.py install 103 | ``` 104 | ### Preparing the data for training 105 | 106 | This project is trained on the [[Cityscapes]](https://www.cityscapes-dataset.com/) dataset. You can run it with the sample dataset in the segmentation/test/dataset/cityspaces folder. If you want to use another dataset, it must follow the format below. 107 | 108 | 1. There are two folders: one for the training images and one for the ground-truth labeled images. 109 | 2. A training image and its labeled image must have the same file name and size. 110 | 3. The training image must be an RGB image, and each pixel of the labeled image must be a class index in the range [0, n_classes), as in the example layout below. 111 |
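For example, the bundled sample dataset is laid out as follows (the file names here come straight from the repository's segmentation/example/dataset folder; any image/label pairs with matching names work the same way):

```
dataset/cityspaces
├── images
│   ├── train
│   │   └── jena_000000_000019.png
│   └── test
│       └── lindau_000000_000019.png
└── labeled
    ├── train
    │   └── jena_000000_000019.png
    └── test
        └── lindau_000000_000019.png
```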
112 | ### Example code to use this project with python 113 | 114 | ```python 115 | import torch 116 | from torchvision import transforms 117 | from segmentation.data_loader.segmentation_dataset import SegmentationDataset 118 | from segmentation.data_loader.transform import Rescale, ToTensor 119 | from segmentation.trainer import Trainer 120 | from segmentation.predict import * 121 | from segmentation.models import all_models 122 | from util.logger import Logger 123 | 124 | train_images = r'dataset/cityspaces/images/train' 125 | test_images = r'dataset/cityspaces/images/test' 126 | train_labeled = r'dataset/cityspaces/labeled/train' 127 | test_labeled = r'dataset/cityspaces/labeled/test' 128 | 129 | if __name__ == '__main__': 130 | model_name = "fcn8_vgg16" 131 | device = 'cuda' 132 | batch_size = 4 133 | n_classes = 34 134 | num_epochs = 10 135 | image_axis_minimum_size = 200 136 | pretrained = True 137 | fixed_feature = False 138 | 139 | logger = Logger(model_name=model_name, data_name='example') 140 | 141 | ### Loader 142 | compose = transforms.Compose([ 143 | Rescale(image_axis_minimum_size), 144 | ToTensor() 145 | ]) 146 | 147 | train_datasets = SegmentationDataset(train_images, train_labeled, n_classes, compose) 148 | train_loader = torch.utils.data.DataLoader(train_datasets, batch_size=batch_size, shuffle=True, drop_last=True) 149 | 150 | test_datasets = SegmentationDataset(test_images, test_labeled, n_classes, compose) 151 | test_loader = torch.utils.data.DataLoader(test_datasets, batch_size=batch_size, shuffle=True, drop_last=True) 152 | 153 | ### Model 154 | model = all_models.model_from_name[model_name](n_classes, batch_size, 155 | pretrained=pretrained, 156 | fixed_feature=fixed_feature) 157 | model.to(device) 158 | 159 | ### Load model 160 | ### Please check the folder: ./segmentation/test/runs/models 161 | #logger.load_model(model, 'epoch_15') 162 | 163 | ### Optimizers 164 | if pretrained and fixed_feature: # fine-tuning 165 | # collect only the parameters that still require gradients 166 | print("Params to learn:") 167 | params_to_update = [] 168 | for name, param in model.named_parameters(): 169 | if param.requires_grad: 170 | params_to_update.append(param) 171 | print("\t", name) 172 | optimizer = torch.optim.Adadelta(params_to_update) 173 | else: 174 | optimizer = torch.optim.Adadelta(model.parameters()) 175 | 176 | ### Train 177 | #scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1) 178 | trainer = Trainer(model, optimizer, logger, num_epochs, train_loader, test_loader) 179 | trainer.train() 180 | 181 | 182 | #### Write the prediction result. 183 | predict(model, r'dataset/cityspaces/input.png', 184 | r'dataset/cityspaces/output.png') 185 | ``` 186 | 187 | ### Pre-trained models (Encoder models) 188 | 189 | This project uses pre-trained encoders such as VGG, ResNet, and MobileNet from the torchvision library. To fine-tune a model, set the 'pretrained' and 'fixed_feature' parameters when constructing it, and then configure the optimizer to update only the unfrozen parameters, as follows. 190 | 191 | ```python 192 | model = all_models.model_from_name[model_name](n_classes, batch_size, 193 | pretrained=pretrained, 194 | fixed_feature=fixed_feature) 195 | 196 | # Optimizers 197 | if pretrained and fixed_feature: # fine-tuning 198 | # collect only the parameters that still require gradients 199 | print("Params to learn:") 200 | params_to_update = [] 201 | for name, param in model.named_parameters(): 202 | if param.requires_grad: 203 | params_to_update.append(param) 204 | print("\t", name) 205 | optimizer = torch.optim.Adadelta(params_to_update) 206 | else: 207 | optimizer = torch.optim.Adadelta(model.parameters()) 208 | 209 | ``` 210 | 211 | ### Getting the learning results on Tensorboard 212 | 213 | The Logger class writes results such as mean IoU, accuracy, loss, and predicted label images. It takes a model name and a data name, and from them generates the tensorboard files automatically in the runs folder, ./segmentation/runs/ 214 | 215 | Here is an example command to see the results 216 | 217 | ```shell 218 | tensorboard --logdir=%project_path\segmentation\runs --host localhost 219 | ``` 220 | 221 | If you are not familiar with Tensorboard, please refer to [[Tensorboard]](https://www.tensorflow.org/tensorboard/get_started) 222 | 223 | ### Saving and loading checkpoints 224 | 225 | The trainer class saves checkpoints automatically, controlled by the 'check_point_epoch_stride' argument: a checkpoint is saved every stride epochs in the runs folder, ./segmentation/runs/models. 226 | 227 | You can also load a checkpoint using the logger class. Please refer to the examples below. 228 | 229 | ```python 230 | 231 | """ 232 | Save checkpoint. 233 | Please check the runs folder, ./segmentation/runs/models 234 | """ 235 | check_point_stride = 30 # a checkpoint is saved every 30 epochs. 236 | 237 | # 'model_name' and 'data_name' determine the path where the checkpoint is saved, 238 | # so pass the same Logger arguments when you load the checkpoint. 239 | logger = Logger(model_name="pspnet_mobilenet_v2", data_name='example') 240 | 241 | trainer = Trainer(model, optimizer, logger, num_epochs, 242 | train_loader, test_loader, 243 | check_point_epoch_stride=check_point_stride) 244 | 245 | ``` 246 | 247 | ```python 248 | """ 249 | Load checkpoint. 250 | """ 251 | n_classes = 33 252 | batch_size = 4 253 | 254 | # The Logger's arguments should be the same as when you trained the model. 255 | logger = Logger(model_name="pspnet_mobilenet_v2", data_name='example') 256 | 257 | model = all_models.model_from_name["pspnet_mobilenet_v2"](n_classes, batch_size) 258 | logger.load_model(model, 'epoch_253') 259 | ``` 260 |
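Combining the two snippets above, resuming training from a saved checkpoint could look like the following sketch. It reuses the device, num_epochs, data loaders, optimizer choice, and Trainer arguments from the training example; whether the Trainer restarts its epoch counter from zero after loading is up to its implementation.

```python
# Rebuild the model, restore the weights saved at epoch 253,
# and hand everything back to the Trainer to continue training.
logger = Logger(model_name="pspnet_mobilenet_v2", data_name='example')

model = all_models.model_from_name["pspnet_mobilenet_v2"](n_classes, batch_size)
logger.load_model(model, 'epoch_253')
model.to(device)

optimizer = torch.optim.Adadelta(model.parameters())
trainer = Trainer(model, optimizer, logger, num_epochs,
                  train_loader, test_loader,
                  check_point_epoch_stride=30)
trainer.train()
```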
261 | ## Cite This Project 262 | If you find this code useful, please consider the following BibTeX entry. 263 | 264 | ```bibtex 265 | @misc{seg-pytorch, 266 | author = {Ian Yoo}, 267 | title = {{semantic-segmentation-pytorch: Pytorch implementation of FCN, UNet, PSPNet and various encoder models}}, 268 | howpublished = {\url{https://github.com/IanTaehoonYoo/semantic-segmentation-pytorch}}, 269 | year = {2020} 270 | } 271 | ``` 272 | -------------------------------------------------------------------------------- /docs/Thumbnail.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/docs/Thumbnail.png -------------------------------------------------------------------------------- /docs/exam_result/ground_truth1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/docs/exam_result/ground_truth1.png -------------------------------------------------------------------------------- /docs/exam_result/ground_truth2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/docs/exam_result/ground_truth2.png -------------------------------------------------------------------------------- /docs/exam_result/input1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/docs/exam_result/input1.png -------------------------------------------------------------------------------- /docs/exam_result/input2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/docs/exam_result/input2.png -------------------------------------------------------------------------------- /docs/exam_result/result1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/docs/exam_result/result1.png -------------------------------------------------------------------------------- /docs/exam_result/result2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/docs/exam_result/result2.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.5.0 2 | torchvision>=0.6.0 3 | numpy 4 | opencv-python 5 | tqdm -------------------------------------------------------------------------------- /segmentation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/__init__.py -------------------------------------------------------------------------------- /segmentation/data_loader/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/data_loader/__init__.py -------------------------------------------------------------------------------- /segmentation/data_loader/segmentation_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dataset class. 3 | 4 | Library: Tensorflow 2.2.0, PyTorch 1.5.1, OpenCV-Python 4.1.1.26 5 | Author: Ian Yoo 6 | Email: thyoostar@gmail.com 7 | """ 8 | from __future__ import absolute_import, print_function, division 9 | import os 10 | import numpy as np 11 | import time 12 | import torch 13 | from torch.utils.data import Dataset 14 | import cv2 15 | 16 | # Ignore warnings 17 | import warnings 18 | warnings.filterwarnings("ignore") 19 | 20 | class DataLoaderError(Exception): 21 | pass 22 | 23 | try: 24 | from tqdm import tqdm 25 | except ImportError: 26 | print("tqdm not found, disabling progress bars") 27 | 28 | def tqdm(iterable, **kwargs):  # fallback stub: ignore tqdm-only kwargs such as ncols 29 | return iterable 30 | 31 | TQDM_COLS = 80 32 | 33 | class SegmentationDataset(Dataset): 34 | """ Segmentation dataset""" 35 | def __init__(self, images_dir, segs_dir, n_classes, transform=None): 36 | """ 37 | Input and annotation images must be matched by file name. 38 | 39 | :param images_dir: path to the image directory 40 | :param segs_dir: path to the annotation image directory 41 | :param n_classes: the number of classes 42 | :param transform: optional transform to be applied on an image 43 | """ 44 | super(SegmentationDataset, self).__init__() 45 | 46 | self.images_dir = images_dir 47 | self.segs_dir = segs_dir 48 | self.transform = transform 49 | self.n_classes = n_classes 50 | 51 | self.pairs_dir = self._get_image_pairs_(self.images_dir, self.segs_dir) 52 | verified = self._verify_segmentation_dataset() 53 | assert verified 54 | 55 | def __len__(self): 56 | return len(self.pairs_dir) 57 | 58 | def __getitem__(self, idx): 59 | if torch.is_tensor(idx): 60 | idx = idx.tolist() 61 | 62 | im = cv2.imread(self.pairs_dir[idx][0], flags=cv2.IMREAD_COLOR) 63 | lbl = cv2.imread(self.pairs_dir[idx][1], flags=cv2.IMREAD_GRAYSCALE) 64 | 65 | sample = {'image': im, 'labeled': lbl} 66 | 67 | if self.transform: 68 | sample = self.transform(sample) 69 | 70 | return sample 71 | def _verify_segmentation_dataset(self): 72 | try: 73 | if not len(self.pairs_dir): 74 | print("Couldn't load any data from the images path: " 75 | "{0} and segmentations path: {1}" 76 | .format(self.images_dir, self.segs_dir)) 77 | return False 78 | 79 | return_value = True 80 | for im_fn, seg_fn in tqdm(self.pairs_dir, ncols=TQDM_COLS): 81 | img = cv2.imread(im_fn) 82 | seg = cv2.imread(seg_fn) 83 | # Check dimensions match 84 | if not img.shape == seg.shape: 85 | return_value = False 86 | print("The size of image {0} and its segmentation {1} " 87 | "doesn't match (possibly the files are corrupt)." 88 | .format(im_fn, seg_fn)) 89 | else: 90 | max_pixel_value = np.max(seg[:, :, 0]) 91 | if max_pixel_value >= self.n_classes: 92 | return_value = False 93 | print("The pixel values of the segmentation image {0} " 94 | "violate the range [0, {1}]. " 95 | "Found maximum pixel value {2}" 96 | .format(seg_fn, str(self.n_classes - 1), max_pixel_value)) 97 | 98 | time.sleep(0.0001) 99 | if return_value: 100 | print("Dataset verified!") 101 | else: 102 | print("Dataset not verified!") 103 | return return_value 104 | except DataLoaderError as e: 105 | print("Found error during data loading\n{0}".format(str(e))) 106 | return False 107 | 108 | def _get_image_pairs_(self, img_path1, img_path2): 109 | """ Collect all images from the two directories and pair them by file name. 110 | :param img_path1: directory 111 | :param img_path2: directory 112 | :return: pair paths 113 | """ 114 | 115 | AVAILABLE_IMAGE_FORMATS = [".jpg", ".jpeg", ".png", ".bmp"] 116 | 117 | files1 = [] 118 | files2 = {} 119 | 120 | for dir_entry in os.listdir(img_path1): 121 | if os.path.isfile(os.path.join(img_path1, dir_entry)) and \ 122 | os.path.splitext(dir_entry)[1] in AVAILABLE_IMAGE_FORMATS: 123 | file_name, file_extension = os.path.splitext(dir_entry) 124 | files1.append((file_name, file_extension, 125 | os.path.join(img_path1, dir_entry))) 126 | 127 | for dir_entry in os.listdir(img_path2): 128 | if os.path.isfile(os.path.join(img_path2, dir_entry)) and \ 129 | os.path.splitext(dir_entry)[1] in AVAILABLE_IMAGE_FORMATS: 130 | file_name, file_extension = os.path.splitext(dir_entry) 131 | full_dir_entry = os.path.join(img_path2, dir_entry) 132 | if file_name in files2: 133 | raise DataLoaderError("img_path2 with filename {0}" 134 | " already exists and is ambiguous to" 135 | " resolve with path {1}." 136 | " Please remove or rename the latter." 137 | .format(file_name, full_dir_entry)) 138 | 139 | files2[file_name] = (file_extension, full_dir_entry) 140 | 141 | return_value = [] 142 | # Match two paths 143 | for image_file, _, image_full_path in files1: 144 | if image_file in files2: 145 | return_value.append((image_full_path, 146 | files2[image_file][1])) 147 | else: 148 | # Error out 149 | raise DataLoaderError("No corresponding images " 150 | "found for image {0}." 151 | .format(image_full_path)) 152 | 153 | return return_value 154 | 155 | 156 | 157 | -------------------------------------------------------------------------------- /segmentation/data_loader/transform.py: -------------------------------------------------------------------------------- 1 | """ 2 | Transforms for the SegmentationDataset. 3 | 4 | Library: Tensorflow 2.2.0, PyTorch 1.5.1, OpenCV-Python 4.1.1.26 5 | Author: Ian Yoo 6 | Email: thyoostar@gmail.com 7 | """ 8 | from __future__ import absolute_import, print_function, division 9 | 10 | import numpy as np 11 | import torch 12 | import cv2 13 | 14 | class Rescale(object): 15 | """Rescale the image in a sample to a given size. 16 | 17 | Args: 18 | output_size (tuple or int): Desired output size. If tuple, output is 19 | matched to output_size. If int, smaller of image edges is matched 20 | to output_size keeping aspect ratio the same. For example, output_size=200 rescales a 1024x2048 (H x W) image to 200x400, since the smaller edge (the height) is matched to 200.
21 | """ 22 | 23 | def __init__(self, output_size): 24 | assert isinstance(output_size, (int, tuple)) 25 | self.output_size = output_size 26 | 27 | def __call__(self, sample): 28 | image, labeled = sample['image'], sample['labeled'] 29 | 30 | h, w = image.shape[:2] 31 | if isinstance(self.output_size, int): 32 | if h > w: 33 | new_h, new_w = self.output_size * h / w, self.output_size 34 | else: 35 | new_h, new_w = self.output_size, self.output_size * w / h 36 | else: 37 | new_h, new_w = self.output_size 38 | 39 | new_h, new_w = int(new_h), int(new_w) 40 | 41 | img = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_NEAREST) 42 | lbl = cv2.resize(labeled, (new_w, new_h), interpolation=cv2.INTER_NEAREST) 43 | 44 | return {'image': img, 'labeled': lbl} 45 | 46 | class RandomHorizontalFlip(torch.nn.Module): 47 | """Horizontally flip the given image randomly with a given probability. 48 | The image can be a PIL Image or a torch Tensor, in which case it is expected 49 | to have [..., H, W] shape, where ... means an arbitrary number of leading 50 | dimensions. 51 | 52 | Args: 53 | p (float): probability of the image being flipped. Default value is 0.5 54 | """ 55 | 56 | def __init__(self, p=0.5): 57 | super().__init__() 58 | self.p = p 59 | 60 | def __call__(self, sample): 61 | image, labeled = sample['image'], sample['labeled'] 62 | 63 | if torch.rand(1) < self.p: 64 | image = cv2.flip(image, 1) 65 | labeled = cv2.flip(labeled, 1) 66 | 67 | return {'image': image, 'labeled': labeled} 68 | 69 | def __repr__(self): 70 | return self.__class__.__name__ + '(p={})'.format(self.p) 71 | 72 | class MakeSegmentationArray(object): 73 | """Make segmentation array from the annotation image""" 74 | 75 | def __init__(self, n_classes): 76 | assert isinstance(n_classes, int) 77 | 78 | self.n_classes = n_classes 79 | 80 | def __call__(self, sample): 81 | annotation = sample['annotation'] 82 | assert np.issubdtype(annotation.dtype, np.integer)  # labels must be an integer array of class indices 83 | 84 | h, w = annotation.shape[:2] 85 | 86 | seg_labels = np.zeros((self.n_classes, h, w), dtype=annotation.dtype) 87 | 88 | for label in range(self.n_classes): 89 | seg_labels[label, :, :] = (annotation == label) 90 | 91 | return {'image': sample['image'], 'annotation': seg_labels} 92 | 93 | class ToTensor(object): 94 | """Convert ndarrays in sample to Tensors.""" 95 | 96 | def __call__(self, sample): 97 | image, lbl = sample['image'], sample['labeled'] 98 | 99 | # swap color axis because 100 | # numpy image: H x W x C 101 | # torch image: C X H X W 102 | image = image.transpose((2, 0, 1)) 103 | return {'image': torch.from_numpy(image).float(), 104 | 'annotation': torch.from_numpy(lbl).long()} -------------------------------------------------------------------------------- /segmentation/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/encoders/__init__.py -------------------------------------------------------------------------------- /segmentation/encoders/mobilenet.py: -------------------------------------------------------------------------------- 1 | """ 2 | MobileNet model customized from Torchvision.
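The encoder's feature layers are grouped into stages (the zip_factor counts below) so that decoder models can copy features at each downsampling point.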
3 | 4 | Library: Tensorflow 2.2.0, PyTorch 1.5.1 5 | Author: Ian Yoo 6 | Email: thyoostar@gmail.com 7 | """ 8 | from __future__ import absolute_import, division 9 | from .squeeze_extractor import * 10 | from torch import nn 11 | 12 | class _Mobilenet(SqueezeExtractor): 13 | def __init__(self, model, features, fixed_feature=True): 14 | layer = [] 15 | layers = [] 16 | self.zip_factor = [2, 4, 7, 11, 14, 17, 18] 17 | layer, layers = self._get_layers(features, layer, layers, 0) 18 | layers = nn.ModuleList(layers) 19 | super(_Mobilenet, self).__init__(model, layers, fixed_feature) 20 | 21 | def _get_layers(self, features, layer, layers, zip_cnt): 22 | from torchvision.models.mobilenet import InvertedResidual, ConvBNReLU 23 | 24 | for feature in features.children(): 25 | if isinstance(feature, nn.Sequential) or\ 26 | isinstance(feature, ConvBNReLU): 27 | layer, layers = self._get_layers(feature, layer, layers, zip_cnt) 28 | if zip_cnt == 0 and isinstance(feature, ConvBNReLU): 29 | layers += [nn.Sequential(*layer)] 30 | layer.clear() 31 | if isinstance(feature, InvertedResidual): 32 | layer, layers = self._get_layers(feature, layer, layers, zip_cnt) 33 | zip_cnt += 1 34 | if zip_cnt in self.zip_factor: 35 | layers += [nn.Sequential(*layer)] 36 | layer.clear() 37 | 38 | if len(list(feature.children())) == 0: 39 | layer += [feature] 40 | 41 | return layer, layers 42 | def get_copy_feature_info(self): 43 | lst_copy_feature_info = [] 44 | channel = 0 45 | for i in range(len(self.features)): 46 | feature = self.features[i] 47 | if isinstance(feature, nn.MaxPool2d): 48 | lst_copy_feature_info.append(CopyFeatureInfo(i, channel)) 49 | for idx, m in enumerate(feature.modules()): 50 | if isinstance(m, nn.Conv2d) and m.stride == (2, 2): 51 | channel = self._get_last_conv2d_out_channels(feature) 52 | lst_copy_feature_info.append(CopyFeatureInfo(i, channel)) 53 | break 54 | 55 | return lst_copy_feature_info 56 | 57 | 58 | def mobilenet(pretrained=False, fixed_feature=True): 59 | """ MobileNet-v2 model from torchvision's mobilenet model. 60 | 61 | :param pretrained: if true, return a model pretrained on ImageNet 62 | :param fixed_feature: if true and pretrained is true, model features are fixed while training. 63 | """ 64 | from torchvision.models.mobilenet import mobilenet_v2 65 | model = mobilenet_v2(pretrained) 66 | features = model.features[:-1] 67 | 68 | ff = True if pretrained and fixed_feature else False 69 | return _Mobilenet(model, features, ff) 70 | -------------------------------------------------------------------------------- /segmentation/encoders/resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | ResNet model customized from Torchvision.
3 | 4 | Library: Tensorflow 2.2.0, PyTorch 1.5.1 5 | Author: Ian Yoo 6 | Email: thyoostar@gmail.com 7 | """ 8 | from __future__ import absolute_import, division 9 | from .squeeze_extractor import * 10 | 11 | 12 | class _ResNet(SqueezeExtractor): 13 | def __init__(self, model, fixed_feature=True): 14 | features = nn.Sequential( 15 | model.conv1, 16 | model.bn1, 17 | model.relu, 18 | model.maxpool, 19 | model.layer1, 20 | model.layer2, 21 | model.layer3, 22 | model.layer4 23 | ) 24 | super(_ResNet, self).__init__(model, features, fixed_feature) 25 | 26 | def get_copy_feature_info(self): 27 | lst_copy_feature_info = [] 28 | channel = 0 29 | for i in range(len(self.features)): 30 | feature = self.features[i] 31 | if isinstance(feature, nn.MaxPool2d): 32 | lst_copy_feature_info.append(CopyFeatureInfo(i, channel)) 33 | for idx, m in enumerate(feature.modules()): 34 | if isinstance(m, nn.Conv2d) and m.stride == (2, 2): 35 | channel = self._get_last_conv2d_out_channels(feature) 36 | lst_copy_feature_info.append(CopyFeatureInfo(i, channel)) 37 | break 38 | 39 | return lst_copy_feature_info 40 | 41 | 42 | 43 | def resnet18(pretrained=False, fixed_feature=True): 44 | """ ResNet-18 model from torchvision's resnet model. 45 | 46 | :param pretrained: if true, return a model pretrained on ImageNet 47 | :param fixed_feature: if true and pretrained is true, model features are fixed while training. 48 | """ 49 | from torchvision.models.resnet import resnet18 50 | model = resnet18(pretrained) 51 | 52 | ff = True if pretrained and fixed_feature else False 53 | return _ResNet(model, ff) 54 | 55 | def resnet34(pretrained=False, fixed_feature=True): 56 | """ ResNet-34 model from torchvision's resnet model. 57 | 58 | :param pretrained: if true, return a model pretrained on ImageNet 59 | :param fixed_feature: if true and pretrained is true, model features are fixed while training. 60 | """ 61 | from torchvision.models.resnet import resnet34 62 | model = resnet34(pretrained) 63 | 64 | ff = True if pretrained and fixed_feature else False 65 | return _ResNet(model, ff) 66 | 67 | def resnet50(pretrained=False, fixed_feature=True): 68 | """ ResNet-50 model from torchvision's resnet model. 69 | 70 | :param pretrained: if true, return a model pretrained on ImageNet 71 | :param fixed_feature: if true and pretrained is true, model features are fixed while training. 72 | """ 73 | from torchvision.models.resnet import resnet50 74 | model = resnet50(pretrained) 75 | 76 | ff = True if pretrained and fixed_feature else False 77 | return _ResNet(model, ff) 78 | 79 | def resnet101(pretrained=False, fixed_feature=True): 80 | """ ResNet-101 model from torchvision's resnet model. 81 | 82 | :param pretrained: if true, return a model pretrained on ImageNet 83 | :param fixed_feature: if true and pretrained is true, model features are fixed while training. 84 | """ 85 | from torchvision.models.resnet import resnet101 86 | model = resnet101(pretrained) 87 | 88 | ff = True if pretrained and fixed_feature else False 89 | return _ResNet(model, ff) 90 | 91 | def resnet152(pretrained=False, fixed_feature=True): 92 | """ ResNet-152 model from torchvision's resnet model. 93 | 94 | :param pretrained: if true, return a model pretrained on ImageNet 95 | :param fixed_feature: if true and pretrained is true, model features are fixed while training.
96 | """ 97 | from torchvision.models.resnet import resnet152 98 | model = resnet152(pretrained) 99 | 100 | ff = True if pretrained and fixed_feature else False 101 | return _ResNet(model, ff) 102 | 103 | -------------------------------------------------------------------------------- /segmentation/encoders/squeeze_extractor.py: -------------------------------------------------------------------------------- 1 | """ 2 | Base class for pre-trained encoder models. 3 | 4 | Library: Tensorflow 2.2.0, PyTorch 1.5.1 5 | Author: Ian Yoo 6 | Email: thyoostar@gmail.com 7 | """ 8 | from __future__ import absolute_import, division 9 | from torch import nn 10 | from dataclasses import dataclass 11 | 12 | @dataclass 13 | class CopyFeatureInfo: 14 | index: int 15 | out_channels: int 16 | 17 | class SqueezeExtractor(nn.Module): 18 | def __init__(self, model, features, fixed_feature=True): 19 | super(SqueezeExtractor, self).__init__() 20 | self.model = model 21 | self.features = features 22 | if fixed_feature: 23 | for param in self.features.parameters(): 24 | param.requires_grad = False 25 | 26 | def get_copy_feature_info(self): 27 | """ 28 | Get a CopyFeatureInfo entry for each downsampling point (max pooling, or a conv2d with 2x2 stride). 29 | :return: list of CopyFeatureInfo. 30 | """ 31 | raise NotImplementedError() 32 | 33 | def _get_last_conv2d_out_channels(self, features): 34 | for idx, m in reversed(list(enumerate(features.modules()))): 35 | if isinstance(m, nn.Conv2d): 36 | return int(m.out_channels) 37 | assert False -------------------------------------------------------------------------------- /segmentation/encoders/vgg.py: -------------------------------------------------------------------------------- 1 | """ 2 | VGG models customized from Torchvision. 3 | 4 | Library: Tensorflow 2.2.0, PyTorch 1.5.1 5 | Author: Ian Yoo 6 | Email: thyoostar@gmail.com 7 | """ 8 | from __future__ import absolute_import, division 9 | from .squeeze_extractor import * 10 | from torch import nn 11 | 12 | 13 | class _VGG(SqueezeExtractor): 14 | def __init__(self, model, features, fixed_feature=True): 15 | super(_VGG, self).__init__(model, features, fixed_feature) 16 | 17 | def get_copy_feature_info(self): 18 | 19 | lst_copy_feature_info = [] 20 | for i in range(len(self.features)): 21 | if isinstance(self.features[i], nn.MaxPool2d): 22 | out_channels = self._get_last_conv2d_out_channels(self.features[:i]) 23 | lst_copy_feature_info.append(CopyFeatureInfo(i, out_channels)) 24 | return lst_copy_feature_info 25 | 26 | def vgg_11(batch_norm=True, pretrained=False, fixed_feature=True): 27 | """ VGG 11-layer model from torchvision's vgg model. 28 | 29 | :param batch_norm: train model with batch normalization 30 | :param pretrained: if true, return a model pretrained on ImageNet 31 | :param fixed_feature: if true and pretrained is true, model features are fixed while training. 32 | """ 33 | if batch_norm: 34 | from torchvision.models.vgg import vgg11_bn 35 | model = vgg11_bn(pretrained) 36 | else: 37 | from torchvision.models.vgg import vgg11 38 | model = vgg11(pretrained) 39 | 40 | ff = True if pretrained and fixed_feature else False 41 | return _VGG(model, model.features, ff) 42 | 43 | def vgg_13(batch_norm=True, pretrained=False, fixed_feature=True): 44 | """ VGG 13-layer model from torchvision's vgg model.
45 | 46 | :param batch_norm: train model with batch normalization 47 | :param pretrained: if true, return a model pretrained on ImageNet 48 | :param fixed_feature: if true and pretrained is true, model features are fixed while training. 49 | """ 50 | if batch_norm: 51 | from torchvision.models.vgg import vgg13_bn 52 | model = vgg13_bn(pretrained) 53 | else: 54 | from torchvision.models.vgg import vgg13 55 | model = vgg13(pretrained) 56 | 57 | ff = True if pretrained and fixed_feature else False 58 | return _VGG(model, model.features, ff) 59 | 60 | def vgg_16(batch_norm=True, pretrained=False, fixed_feature=True): 61 | """ VGG 16-layer model from torchvision's vgg model. 62 | 63 | :param batch_norm: train model with batch normalization 64 | :param pretrained: if true, return a model pretrained on ImageNet 65 | :param fixed_feature: if true and pretrained is true, model features are fixed while training. 66 | """ 67 | if batch_norm: 68 | from torchvision.models.vgg import vgg16_bn 69 | model = vgg16_bn(pretrained) 70 | else: 71 | from torchvision.models.vgg import vgg16 72 | model = vgg16(pretrained) 73 | 74 | ff = True if pretrained and fixed_feature else False 75 | return _VGG(model, model.features, ff) 76 | 77 | def vgg_19(batch_norm=True, pretrained=False, fixed_feature=True): 78 | """ VGG 19-layer model from torchvision's vgg model. 79 | 80 | :param batch_norm: train model with batch normalization 81 | :param pretrained: if true, return a model pretrained on ImageNet 82 | :param fixed_feature: if true and pretrained is true, model features are fixed while training. 83 | """ 84 | if batch_norm: 85 | from torchvision.models.vgg import vgg19_bn 86 | model = vgg19_bn(pretrained) 87 | else: 88 | from torchvision.models.vgg import vgg19 89 | model = vgg19(pretrained) 90 | 91 | ff = True if pretrained and fixed_feature else False 92 | return _VGG(model, model.features, ff) 93 | -------------------------------------------------------------------------------- /segmentation/example/dataset/cityspaces/images/test/lindau_000000_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/example/dataset/cityspaces/images/test/lindau_000000_000019.png -------------------------------------------------------------------------------- /segmentation/example/dataset/cityspaces/images/test/lindau_000001_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/example/dataset/cityspaces/images/test/lindau_000001_000019.png -------------------------------------------------------------------------------- /segmentation/example/dataset/cityspaces/images/test/lindau_000002_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/example/dataset/cityspaces/images/test/lindau_000002_000019.png -------------------------------------------------------------------------------- /segmentation/example/dataset/cityspaces/images/test/lindau_000003_000019.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/example/dataset/cityspaces/images/test/lindau_000003_000019.png -------------------------------------------------------------------------------- /segmentation/example/dataset/cityspaces/images/test/lindau_000004_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/example/dataset/cityspaces/images/test/lindau_000004_000019.png -------------------------------------------------------------------------------- /segmentation/example/dataset/cityspaces/images/test/lindau_000005_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/example/dataset/cityspaces/images/test/lindau_000005_000019.png -------------------------------------------------------------------------------- /segmentation/example/dataset/cityspaces/images/train/jena_000000_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/example/dataset/cityspaces/images/train/jena_000000_000019.png -------------------------------------------------------------------------------- /segmentation/example/dataset/cityspaces/images/train/jena_000001_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/example/dataset/cityspaces/images/train/jena_000001_000019.png -------------------------------------------------------------------------------- /segmentation/example/dataset/cityspaces/images/train/jena_000002_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/example/dataset/cityspaces/images/train/jena_000002_000019.png -------------------------------------------------------------------------------- /segmentation/example/dataset/cityspaces/images/train/jena_000003_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/example/dataset/cityspaces/images/train/jena_000003_000019.png -------------------------------------------------------------------------------- /segmentation/example/dataset/cityspaces/images/train/jena_000004_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/example/dataset/cityspaces/images/train/jena_000004_000019.png -------------------------------------------------------------------------------- /segmentation/example/dataset/cityspaces/images/train/jena_000005_000019.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/example/dataset/cityspaces/images/train/jena_000005_000019.png -------------------------------------------------------------------------------- /segmentation/example/dataset/cityspaces/images/train/jena_000006_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/example/dataset/cityspaces/images/train/jena_000006_000019.png -------------------------------------------------------------------------------- /segmentation/example/dataset/cityspaces/images/train/jena_000007_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/example/dataset/cityspaces/images/train/jena_000007_000019.png -------------------------------------------------------------------------------- /segmentation/example/dataset/cityspaces/images/train/jena_000008_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/example/dataset/cityspaces/images/train/jena_000008_000019.png -------------------------------------------------------------------------------- /segmentation/example/dataset/cityspaces/images/train/jena_000009_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/example/dataset/cityspaces/images/train/jena_000009_000019.png -------------------------------------------------------------------------------- /segmentation/example/dataset/cityspaces/images/train/jena_000010_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/example/dataset/cityspaces/images/train/jena_000010_000019.png -------------------------------------------------------------------------------- /segmentation/example/dataset/cityspaces/labeled/test/lindau_000000_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/example/dataset/cityspaces/labeled/test/lindau_000000_000019.png -------------------------------------------------------------------------------- /segmentation/example/dataset/cityspaces/labeled/test/lindau_000001_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/example/dataset/cityspaces/labeled/test/lindau_000001_000019.png -------------------------------------------------------------------------------- /segmentation/example/dataset/cityspaces/labeled/test/lindau_000002_000019.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/example/dataset/cityspaces/labeled/test/lindau_000002_000019.png -------------------------------------------------------------------------------- /segmentation/example/dataset/cityspaces/labeled/test/lindau_000003_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/example/dataset/cityspaces/labeled/test/lindau_000003_000019.png -------------------------------------------------------------------------------- /segmentation/example/dataset/cityspaces/labeled/test/lindau_000004_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/example/dataset/cityspaces/labeled/test/lindau_000004_000019.png -------------------------------------------------------------------------------- /segmentation/example/dataset/cityspaces/labeled/test/lindau_000005_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/example/dataset/cityspaces/labeled/test/lindau_000005_000019.png -------------------------------------------------------------------------------- /segmentation/example/dataset/cityspaces/labeled/train/jena_000000_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/example/dataset/cityspaces/labeled/train/jena_000000_000019.png -------------------------------------------------------------------------------- /segmentation/example/dataset/cityspaces/labeled/train/jena_000001_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/example/dataset/cityspaces/labeled/train/jena_000001_000019.png -------------------------------------------------------------------------------- /segmentation/example/dataset/cityspaces/labeled/train/jena_000002_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/example/dataset/cityspaces/labeled/train/jena_000002_000019.png -------------------------------------------------------------------------------- /segmentation/example/dataset/cityspaces/labeled/train/jena_000003_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/example/dataset/cityspaces/labeled/train/jena_000003_000019.png -------------------------------------------------------------------------------- /segmentation/example/dataset/cityspaces/labeled/train/jena_000004_000019.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/example/dataset/cityspaces/labeled/train/jena_000004_000019.png -------------------------------------------------------------------------------- /segmentation/example/dataset/cityspaces/labeled/train/jena_000005_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/example/dataset/cityspaces/labeled/train/jena_000005_000019.png -------------------------------------------------------------------------------- /segmentation/example/dataset/cityspaces/labeled/train/jena_000006_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/example/dataset/cityspaces/labeled/train/jena_000006_000019.png -------------------------------------------------------------------------------- /segmentation/example/dataset/cityspaces/labeled/train/jena_000007_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/example/dataset/cityspaces/labeled/train/jena_000007_000019.png -------------------------------------------------------------------------------- /segmentation/example/dataset/cityspaces/labeled/train/jena_000008_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/example/dataset/cityspaces/labeled/train/jena_000008_000019.png -------------------------------------------------------------------------------- /segmentation/example/dataset/cityspaces/labeled/train/jena_000009_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/example/dataset/cityspaces/labeled/train/jena_000009_000019.png -------------------------------------------------------------------------------- /segmentation/example/dataset/cityspaces/labeled/train/jena_000010_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/example/dataset/cityspaces/labeled/train/jena_000010_000019.png -------------------------------------------------------------------------------- /segmentation/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/models/__init__.py -------------------------------------------------------------------------------- /segmentation/models/all_models.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | from .fcn8 import * 4 | from .fcn16 import * 5 | from .fcn32 import * 6 | from .unet import * 7 | from .pspnet 
import * 8 | 9 | model_from_name = {} 10 | 11 | model_from_name["fcn8_vgg11"] = fcn8_vgg11 12 | model_from_name["fcn8_vgg13"] = fcn8_vgg13 13 | model_from_name["fcn8_vgg16"] = fcn8_vgg16 14 | model_from_name["fcn8_vgg19"] = fcn8_vgg19 15 | model_from_name["fcn16_vgg11"] = fcn16_vgg11 16 | model_from_name["fcn16_vgg13"] = fcn16_vgg13 17 | model_from_name["fcn16_vgg16"] = fcn16_vgg16 18 | model_from_name["fcn16_vgg19"] = fcn16_vgg19 19 | model_from_name["fcn32_vgg11"] = fcn32_vgg11 20 | model_from_name["fcn32_vgg13"] = fcn32_vgg13 21 | model_from_name["fcn32_vgg16"] = fcn32_vgg16 22 | model_from_name["fcn32_vgg19"] = fcn32_vgg19 23 | model_from_name["fcn8_resnet18"] = fcn8_resnet18 24 | model_from_name["fcn8_resnet34"] = fcn8_resnet34 25 | model_from_name["fcn8_resnet50"] = fcn8_resnet50 26 | model_from_name["fcn8_resnet101"] = fcn8_resnet101 27 | model_from_name["fcn8_resnet152"] = fcn8_resnet152 28 | model_from_name["fcn16_resnet18"] = fcn16_resnet18 29 | model_from_name["fcn16_resnet34"] = fcn16_resnet34 30 | model_from_name["fcn16_resnet50"] = fcn16_resnet50 31 | model_from_name["fcn16_resnet101"] = fcn16_resnet101 32 | model_from_name["fcn16_resnet152"] = fcn16_resnet152 33 | model_from_name["fcn32_resnet18"] = fcn32_resnet18 34 | model_from_name["fcn32_resnet34"] = fcn32_resnet34 35 | model_from_name["fcn32_resnet50"] = fcn32_resnet50 36 | model_from_name["fcn32_resnet101"] = fcn32_resnet101 37 | model_from_name["fcn32_resnet152"] = fcn32_resnet152 38 | model_from_name["fcn8_mobilenet_v2"] = fcn8_mobilenet_v2 39 | model_from_name["fcn16_mobilenet_v2"] = fcn16_mobilenet_v2 40 | model_from_name["fcn32_mobilenet_v2"] = fcn32_mobilenet_v2 41 | 42 | model_from_name["unet"] = unet 43 | model_from_name["unet_vgg11"] = unet_vgg11 44 | model_from_name["unet_vgg13"] = unet_vgg13 45 | model_from_name["unet_vgg16"] = unet_vgg16 46 | model_from_name["unet_vgg19"] = unet_vgg19 47 | model_from_name["unet_resnet18"] = unet_resnet18 48 | model_from_name["unet_resnet34"] = unet_resnet34 49 | model_from_name["unet_resnet50"] = unet_resnet50 50 | model_from_name["unet_resnet101"] = unet_resnet101 51 | model_from_name["unet_resnet152"] = unet_resnet152 52 | model_from_name["unet_mobilenet_v2"] = unet_mobilenet_v2 53 | 54 | model_from_name["pspnet_vgg11"] = pspnet_vgg11 55 | model_from_name["pspnet_vgg13"] = pspnet_vgg13 56 | model_from_name["pspnet_vgg16"] = pspnet_vgg16 57 | model_from_name["pspnet_vgg19"] = pspnet_vgg19 58 | model_from_name["pspnet_resnet18"] = pspnet_resnet18 59 | model_from_name["pspnet_resnet34"] = pspnet_resnet34 60 | model_from_name["pspnet_resnet50"] = pspnet_resnet50 61 | model_from_name["pspnet_resnet101"] = pspnet_resnet101 62 | model_from_name["pspnet_resnet152"] = pspnet_resnet152 63 | model_from_name["pspnet_mobilenet_v2"] = pspnet_mobilenet_v2 64 | 65 | 66 | 67 | 68 | -------------------------------------------------------------------------------- /segmentation/models/fcn16.py: -------------------------------------------------------------------------------- 1 | """ 2 | FCN16 class. 
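The decoder upsamples the classifier's score map 2x, adds the 1x1-convolved scores of the pool4 feature map, and then upsamples the fused map 16x back to the input resolution.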
3 | 4 | Library: TensorFlow 2.2.0, PyTorch 1.5.1 5 | Author: Ian Yoo 6 | Email: thyoostar@gmail.com 7 | """ 8 | from __future__ import absolute_import, division, print_function 9 | 10 | import torch 11 | from ..encoders.squeeze_extractor import * 12 | 13 | class FCN16(torch.nn.Module): 14 | 15 | def __init__(self, n_classes, pretrained_model: SqueezeExtractor): 16 | super(FCN16, self).__init__() 17 | self.pretrained_model = pretrained_model 18 | self.features = self.pretrained_model.features 19 | self.copy_feature_info = pretrained_model.get_copy_feature_info() 20 | 21 | self.score_pool4 = nn.Conv2d(self.copy_feature_info[-2].out_channels, 22 | n_classes, kernel_size=1) 23 | 24 | self.upsampling2 = nn.ConvTranspose2d(n_classes, n_classes, kernel_size=4, 25 | stride=2, bias=False) 26 | self.upsampling16 = nn.ConvTranspose2d(n_classes, n_classes, kernel_size=32, 27 | stride=16, bias=False) 28 | 29 | for m in self.features.modules(): 30 | if isinstance(m, nn.Conv2d): 31 | channels = m.out_channels 32 | 33 | self.classifier = nn.Sequential(nn.Conv2d(channels, n_classes, kernel_size=1), nn.Sigmoid()) 34 | self._initialize_weights() 35 | 36 | def _initialize_weights(self): 37 | for m in self.modules(): 38 | if isinstance(m, nn.Conv2d): 39 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 40 | if m.bias is not None: 41 | nn.init.constant_(m.bias, 0) 42 | elif isinstance(m, nn.Linear): 43 | nn.init.normal_(m.weight, 0, 0.01) 44 | nn.init.constant_(m.bias, 0) 45 | 46 | def forward(self, x): 47 | last_feature_index = self.copy_feature_info[-2].index 48 | 49 | o = x 50 | for i in range(len(self.features)): 51 | o = self.features[i](o) 52 | if i == last_feature_index: 53 | pool4 = o 54 | 55 | o = self.classifier(o) 56 | o = self.upsampling2(o) 57 | 58 | o2 = self.score_pool4(pool4) 59 | o = o[:, :, 1:1 + o2.size()[2], 1:1 + o2.size()[3]] 60 | o = o + o2 61 | 62 | o = self.upsampling16(o) 63 | cx = int((o.shape[3] - x.shape[3]) / 2) 64 | cy = int((o.shape[2] - x.shape[2]) / 2) 65 | o = o[:, :, cy:cy + x.shape[2], cx:cx + x.shape[3]] 66 | 67 | return o 68 | 69 | from ..encoders.vgg import * 70 | from ..encoders.resnet import * 71 | from ..encoders.mobilenet import * 72 | 73 | def fcn16_vgg11(n_classes, batch_size, pretrained=False, fixed_feature=True): 74 | batch_norm = False if batch_size == 1 else True 75 | vgg = vgg_11(batch_norm, pretrained, fixed_feature) 76 | return FCN16(n_classes, vgg) 77 | def fcn16_vgg13(n_classes, batch_size, pretrained=False, fixed_feature=True): 78 | batch_norm = False if batch_size == 1 else True 79 | vgg = vgg_13(batch_norm, pretrained, fixed_feature) 80 | return FCN16(n_classes, vgg) 81 | def fcn16_vgg16(n_classes, batch_size, pretrained=False, fixed_feature=True): 82 | batch_norm = False if batch_size == 1 else True 83 | vgg = vgg_16(batch_norm, pretrained, fixed_feature) 84 | return FCN16(n_classes, vgg) 85 | def fcn16_vgg19(n_classes, batch_size, pretrained=False, fixed_feature=True): 86 | batch_norm = False if batch_size == 1 else True 87 | vgg = vgg_19(batch_norm, pretrained, fixed_feature) 88 | return FCN16(n_classes, vgg) 89 | 90 | def fcn16_resnet18(n_classes, batch_size, pretrained=False, fixed_feature=True): 91 | batch_norm = False if batch_size == 1 else True 92 | resnet = resnet18(pretrained, fixed_feature) 93 | return FCN16(n_classes, resnet) 94 | def fcn16_resnet34(n_classes, batch_size, pretrained=False, fixed_feature=True): 95 | batch_norm = False if batch_size == 1 else True 96 | resnet = resnet34(pretrained, fixed_feature) 97 | 
return FCN16(n_classes, resnet) 98 | def fcn16_resnet50(n_classes, batch_size, pretrained=False, fixed_feature=True): 99 | batch_norm = False if batch_size == 1 else True 100 | resnet = resnet50(pretrained, fixed_feature) 101 | return FCN16(n_classes, resnet) 102 | def fcn16_resnet101(n_classes, batch_size, pretrained=False, fixed_feature=True): 103 | batch_norm = False if batch_size == 1 else True 104 | resnet = resnet101(pretrained, fixed_feature) 105 | return FCN16(n_classes, resnet) 106 | def fcn16_resnet152(n_classes, batch_size, pretrained=False, fixed_feature=True): 107 | batch_norm = False if batch_size == 1 else True 108 | resnet = resnet152(pretrained, fixed_feature) 109 | return FCN16(n_classes, resnet) 110 | 111 | def fcn16_mobilenet_v2(n_classes, batch_size, pretrained=False, fixed_feature=True): 112 | batch_norm = False if batch_size == 1 else True 113 | mobile_net = mobilenet(pretrained, fixed_feature) 114 | return FCN16(n_classes, mobile_net) -------------------------------------------------------------------------------- /segmentation/models/fcn32.py: -------------------------------------------------------------------------------- 1 | """ 2 | FCN32 class. 3 | 4 | Library: TensorFlow 2.2.0, PyTorch 1.5.1 5 | Author: Ian Yoo 6 | Email: thyoostar@gmail.com 7 | """ 8 | from __future__ import absolute_import, division, print_function 9 | 10 | import torch 11 | from ..encoders.squeeze_extractor import * 12 | 13 | class FCN32(torch.nn.Module): 14 | 15 | def __init__(self, n_classes, pretrained_model: SqueezeExtractor): 16 | super(FCN32, self).__init__() 17 | self.pretrained_model = pretrained_model 18 | self.features = self.pretrained_model.features 19 | self.upsampling32 = nn.ConvTranspose2d(n_classes, n_classes, kernel_size=64, 20 | stride=32, bias=False) 21 | 22 | for m in self.features.modules(): 23 | if isinstance(m, nn.Conv2d): 24 | channels = m.out_channels 25 | 26 | self.classifier = nn.Sequential(nn.Conv2d(channels, n_classes, kernel_size=1), nn.Sigmoid()) 27 | self._initialize_weights() 28 | 29 | def _initialize_weights(self): 30 | for m in self.modules(): 31 | if isinstance(m, nn.Conv2d): 32 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 33 | if m.bias is not None: 34 | nn.init.constant_(m.bias, 0) 35 | elif isinstance(m, nn.Linear): 36 | nn.init.normal_(m.weight, 0, 0.01) 37 | nn.init.constant_(m.bias, 0) 38 | 39 | def forward(self, x): 40 | o = x 41 | for feature in self.features: 42 | o = feature(o) 43 | o = self.classifier(o) 44 | o = self.upsampling32(o) 45 | cx = int((o.shape[3] - x.shape[3]) / 2) 46 | cy = int((o.shape[2] - x.shape[2]) / 2) 47 | o = o[:, :, cy:cy + x.shape[2], cx:cx + x.shape[3]] 48 | 49 | return o 50 | 51 | from ..encoders.vgg import * 52 | from ..encoders.resnet import * 53 | from ..encoders.mobilenet import * 54 | 55 | def fcn32_vgg11(n_classes, batch_size, pretrained=False, fixed_feature=True): 56 | batch_norm = False if batch_size == 1 else True 57 | vgg = vgg_11(batch_norm, pretrained, fixed_feature) 58 | return FCN32(n_classes, vgg) 59 | def fcn32_vgg13(n_classes, batch_size, pretrained=False, fixed_feature=True): 60 | batch_norm = False if batch_size == 1 else True 61 | vgg = vgg_13(batch_norm, pretrained, fixed_feature) 62 | return FCN32(n_classes, vgg) 63 | def fcn32_vgg16(n_classes, batch_size, pretrained=False, fixed_feature=True): 64 | batch_norm = False if batch_size == 1 else True 65 | vgg = vgg_16(batch_norm, pretrained, fixed_feature) 66 | return FCN32(n_classes, vgg) 67 | def fcn32_vgg19(n_classes, 
batch_size, pretrained=False, fixed_feature=True): 68 | batch_norm = False if batch_size == 1 else True 69 | vgg = vgg_19(batch_norm, pretrained, fixed_feature) 70 | return FCN32(n_classes, vgg) 71 | 72 | def fcn32_resnet18(n_classes, batch_size, pretrained=False, fixed_feature=True): 73 | batch_norm = False if batch_size == 1 else True 74 | resnet = resnet18(pretrained, fixed_feature) 75 | return FCN32(n_classes, resnet) 76 | def fcn32_resnet34(n_classes, batch_size, pretrained=False, fixed_feature=True): 77 | batch_norm = False if batch_size == 1 else True 78 | resnet = resnet34(pretrained, fixed_feature) 79 | return FCN32(n_classes, resnet) 80 | def fcn32_resnet50(n_classes, batch_size, pretrained=False, fixed_feature=True): 81 | batch_norm = False if batch_size == 1 else True 82 | resnet = resnet50(pretrained, fixed_feature) 83 | return FCN32(n_classes, resnet) 84 | def fcn32_resnet101(n_classes, batch_size, pretrained=False, fixed_feature=True): 85 | batch_norm = False if batch_size == 1 else True 86 | resnet = resnet101(pretrained, fixed_feature) 87 | return FCN32(n_classes, resnet) 88 | def fcn32_resnet152(n_classes, batch_size, pretrained=False, fixed_feature=True): 89 | batch_norm = False if batch_size == 1 else True 90 | resnet = resnet152(pretrained, fixed_feature) 91 | return FCN32(n_classes, resnet) 92 | 93 | def fcn32_mobilenet_v2(n_classes, batch_size, pretrained=False, fixed_feature=True): 94 | batch_norm = False if batch_size == 1 else True 95 | mobile_net = mobilenet(pretrained, fixed_feature) 96 | return FCN32(n_classes, mobile_net) 97 | -------------------------------------------------------------------------------- /segmentation/models/fcn8.py: -------------------------------------------------------------------------------- 1 | """ 2 | FCN8 class. 
3 | 4 | Library: TensorFlow 2.2.0, PyTorch 1.5.1 5 | Author: Ian Yoo 6 | Email: thyoostar@gmail.com 7 | """ 8 | from __future__ import absolute_import, division, print_function 9 | 10 | import torch 11 | from ..encoders.squeeze_extractor import * 12 | 13 | class FCN8(torch.nn.Module): 14 | 15 | def __init__(self, n_classes, pretrained_model: SqueezeExtractor): 16 | super(FCN8, self).__init__() 17 | self.features = pretrained_model.features 18 | self.copy_feature_info = pretrained_model.get_copy_feature_info() 19 | self.score_pool3 = nn.Conv2d(self.copy_feature_info[-3].out_channels, 20 | n_classes, kernel_size=1) 21 | self.score_pool4 = nn.Conv2d(self.copy_feature_info[-2].out_channels, 22 | n_classes, kernel_size=1) 23 | 24 | self.upsampling2 = nn.ConvTranspose2d(n_classes, n_classes, kernel_size=4, 25 | stride=2, bias=False) 26 | self.upsampling8 = nn.ConvTranspose2d(n_classes, n_classes, kernel_size=16, 27 | stride=8, bias=False) 28 | 29 | for m in self.features.modules(): 30 | if isinstance(m, nn.Conv2d): 31 | channels = m.out_channels 32 | 33 | self.classifier = nn.Sequential(nn.Conv2d(channels, n_classes, kernel_size=1), nn.Sigmoid()) 34 | self._initialize_weights() 35 | 36 | def _initialize_weights(self): 37 | for m in self.modules(): 38 | if isinstance(m, nn.Conv2d): 39 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 40 | if m.bias is not None: 41 | nn.init.constant_(m.bias, 0) 42 | elif isinstance(m, nn.Linear): 43 | nn.init.normal_(m.weight, 0, 0.01) 44 | nn.init.constant_(m.bias, 0) 45 | 46 | def forward(self, x): 47 | saved_pools = [] 48 | 49 | o = x 50 | for i in range(len(self.features)): 51 | o = self.features[i](o) 52 | if i == self.copy_feature_info[-3].index or\ 53 | i == self.copy_feature_info[-2].index: 54 | saved_pools.append(o) 55 | 56 | o = self.classifier(o) 57 | o = self.upsampling2(o) 58 | 59 | o2 = self.score_pool4(saved_pools[1]) 60 | o = o[:, :, 1:1 + o2.size()[2], 1:1 + o2.size()[3]] 61 | o = o + o2 62 | 63 | o = self.upsampling2(o) 64 | 65 | o2 = self.score_pool3(saved_pools[0]) 66 | o = o[:, :, 1:1 + o2.size()[2], 1:1 + o2.size()[3]] 67 | o = o + o2 68 | 69 | o = self.upsampling8(o) 70 | cx = int((o.shape[3] - x.shape[3]) / 2) 71 | cy = int((o.shape[2] - x.shape[2]) / 2) 72 | o = o[:, :, cy:cy + x.shape[2], cx:cx + x.shape[3]] 73 | 74 | return o 75 | 76 | from ..encoders.vgg import * 77 | from ..encoders.resnet import * 78 | from ..encoders.mobilenet import * 79 | 80 | def fcn8_vgg11(n_classes, batch_size, pretrained=False, fixed_feature=True): 81 | batch_norm = False if batch_size == 1 else True 82 | vgg = vgg_11(batch_norm, pretrained, fixed_feature) 83 | return FCN8(n_classes, vgg) 84 | def fcn8_vgg13(n_classes, batch_size, pretrained=False, fixed_feature=True): 85 | batch_norm = False if batch_size == 1 else True 86 | vgg = vgg_13(batch_norm, pretrained, fixed_feature) 87 | return FCN8(n_classes, vgg) 88 | def fcn8_vgg16(n_classes, batch_size, pretrained=False, fixed_feature=True): 89 | batch_norm = False if batch_size == 1 else True 90 | vgg = vgg_16(batch_norm, pretrained, fixed_feature) 91 | return FCN8(n_classes, vgg) 92 | def fcn8_vgg19(n_classes, batch_size, pretrained=False, fixed_feature=True): 93 | batch_norm = False if batch_size == 1 else True 94 | vgg = vgg_19(batch_norm, pretrained, fixed_feature) 95 | return FCN8(n_classes, vgg) 96 | 97 | def fcn8_resnet18(n_classes, batch_size, pretrained=False, fixed_feature=True): 98 | batch_norm = False if batch_size == 1 else True 99 | resnet = resnet18(pretrained, 
fixed_feature) 100 | return FCN8(n_classes, resnet) 101 | def fcn8_resnet34(n_classes, batch_size, pretrained=False, fixed_feature=True): 102 | batch_norm = False if batch_size == 1 else True 103 | resnet = resnet34(pretrained, fixed_feature) 104 | return FCN8(n_classes, resnet) 105 | def fcn8_resnet50(n_classes, batch_size, pretrained=False, fixed_feature=True): 106 | batch_norm = False if batch_size == 1 else True 107 | resnet = resnet50(pretrained, fixed_feature) 108 | return FCN8(n_classes, resnet) 109 | def fcn8_resnet101(n_classes, batch_size, pretrained=False, fixed_feature=True): 110 | batch_norm = False if batch_size == 1 else True 111 | resnet = resnet101(pretrained, fixed_feature) 112 | return FCN8(n_classes, resnet) 113 | def fcn8_resnet152(n_classes, batch_size, pretrained=False, fixed_feature=True): 114 | batch_norm = False if batch_size == 1 else True 115 | resnet = resnet152(pretrained, fixed_feature) 116 | return FCN8(n_classes, resnet) 117 | 118 | def fcn8_mobilenet_v2(n_classes, batch_size, pretrained=False, fixed_feature=True): 119 | batch_norm = False if batch_size == 1 else True 120 | mobile_net = mobilenet(pretrained, fixed_feature) 121 | return FCN8(n_classes, mobile_net) 122 | -------------------------------------------------------------------------------- /segmentation/models/pspnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | PSPnet class. 3 | 4 | Library: TensorFlow 2.2.0, PyTorch 1.5.1 5 | Author: Ian Yoo 6 | Email: thyoostar@gmail.com 7 | """ 8 | from __future__ import absolute_import, division, print_function 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | from ..encoders.squeeze_extractor import * 13 | 14 | class PSPModule(nn.Module): 15 | def __init__(self, in_channels, out_channels=1024, pool_factors=(1, 2, 3, 6), batch_norm=True): 16 | super().__init__() 17 | self.spatial_blocks = [] 18 | for pf in pool_factors: 19 | self.spatial_blocks += [self._make_spatial_block(in_channels, pf, batch_norm)] 20 | self.spatial_blocks = nn.ModuleList(self.spatial_blocks) 21 | 22 | bottleneck = [] 23 | bottleneck += [nn.Conv2d(in_channels * (len(pool_factors) + 1), out_channels, kernel_size=1)] 24 | if batch_norm: 25 | bottleneck += [nn.BatchNorm2d(out_channels)] 26 | bottleneck += [nn.ReLU(inplace=True)] 27 | self.bottleneck = nn.Sequential(*bottleneck) 28 | 29 | def _make_spatial_block(self, in_channels, pool_factor, batch_norm): 30 | spatial_block = [] 31 | spatial_block += [nn.AdaptiveAvgPool2d(output_size=(pool_factor, pool_factor))] 32 | spatial_block += [nn.Conv2d(in_channels, in_channels, kernel_size=1, bias=False)] 33 | if batch_norm: 34 | spatial_block += [nn.BatchNorm2d(in_channels)] 35 | spatial_block += [nn.ReLU(inplace=True)] 36 | 37 | return nn.Sequential(*spatial_block) 38 | 39 | def forward(self, x): 40 | h, w = x.size(2), x.size(3) 41 | pool_outs = [x] 42 | for block in self.spatial_blocks: 43 | pooled = block(x) 44 | pool_outs += [F.upsample(pooled, size=(h, w), mode='bilinear')] 45 | o = torch.cat(pool_outs, dim=1) 46 | o = self.bottleneck(o) 47 | return o 48 | 49 | def _initialize_weights(self): 50 | for m in self.modules(): 51 | if isinstance(m, nn.Conv2d): 52 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 53 | if m.bias is not None: 54 | nn.init.constant_(m.bias, 0) 55 | elif isinstance(m, nn.BatchNorm2d): 56 | nn.init.constant_(m.weight, 1) 57 | nn.init.constant_(m.bias, 0) 58 | 59 | class PSPUpsampling(nn.Module): 60 | def __init__(self, in_channels, 
out_channels, batch_norm=True): 61 | super().__init__() 62 | layers = [] 63 | layers += [nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)] 64 | if batch_norm: 65 | layers += [nn.BatchNorm2d(out_channels)] 66 | layers += [nn.ReLU(inplace=True)] 67 | self.layer = nn.Sequential(*layers) 68 | 69 | def forward(self, x): 70 | h, w = 2 * x.size(2), 2 * x.size(3) 71 | p = F.upsample(x, size=(h, w), mode='bilinear') 72 | return self.layer(p) 73 | 74 | def _initialize_weights(self): 75 | for m in self.modules(): 76 | if isinstance(m, nn.Conv2d): 77 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 78 | if m.bias is not None: 79 | nn.init.constant_(m.bias, 0) 80 | elif isinstance(m, nn.BatchNorm2d): 81 | nn.init.constant_(m.weight, 1) 82 | nn.init.constant_(m.bias, 0) 83 | 84 | class PSPnet(torch.nn.Module): 85 | 86 | def __init__(self, n_classes, pretrained_model: SqueezeExtractor, batch_norm=True, psp_out_feature=1024): 87 | super(PSPnet, self).__init__() 88 | self.features = pretrained_model.features 89 | 90 | # find out_channels of the top layer and define classifier 91 | for idx, m in reversed(list(enumerate(self.features.modules()))): 92 | if isinstance(m, nn.Conv2d): 93 | channels = m.out_channels 94 | break 95 | 96 | self.PSP = PSPModule(channels, out_channels=psp_out_feature, batch_norm=batch_norm) 97 | h_psp_out_feature = int(psp_out_feature / 2) 98 | q_psp_out_feature = int(psp_out_feature / 4) 99 | e_psp_out_feature = int(psp_out_feature / 8) 100 | self.upsampling1 = PSPUpsampling(psp_out_feature, h_psp_out_feature, batch_norm=batch_norm) 101 | self.upsampling2 = PSPUpsampling(h_psp_out_feature, q_psp_out_feature, batch_norm=batch_norm) 102 | self.upsampling3 = PSPUpsampling(q_psp_out_feature, e_psp_out_feature, batch_norm=batch_norm) 103 | 104 | self.classifier = nn.Sequential(nn.Conv2d(e_psp_out_feature, n_classes, kernel_size=1)) 105 | 106 | self._initialize_weights() 107 | 108 | def _initialize_weights(self): 109 | for m in self.modules(): 110 | if isinstance(m, nn.Conv2d): 111 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 112 | if m.bias is not None: 113 | nn.init.constant_(m.bias, 0) 114 | elif isinstance(m, nn.BatchNorm2d): 115 | nn.init.constant_(m.weight, 1) 116 | nn.init.constant_(m.bias, 0) 117 | elif isinstance(m, nn.Linear): 118 | nn.init.normal_(m.weight, 0, 0.01) 119 | nn.init.constant_(m.bias, 0) 120 | 121 | def forward(self, x): 122 | o = x 123 | for f in self.features: 124 | o = f(o) 125 | 126 | o = self.PSP(o) 127 | o = self.upsampling1(o) 128 | o = self.upsampling2(o) 129 | o = self.upsampling3(o) 130 | 131 | o = F.upsample(o, size=(x.shape[2], x.shape[3]), mode='bilinear') 132 | o = self.classifier(o) 133 | 134 | return o 135 | 136 | from ..encoders.vgg import * 137 | from ..encoders.resnet import * 138 | from ..encoders.mobilenet import * 139 | 140 | def pspnet_vgg11(n_classes, batch_size, pretrained=False, fixed_feature=True): 141 | batch_norm = False if batch_size == 1 else True 142 | vgg = vgg_11(batch_norm, pretrained, fixed_feature) 143 | copy_feature_info = vgg.get_copy_feature_info() 144 | squeeze_feature_idx = copy_feature_info[3].index - 1 145 | vgg.features = vgg.features[:squeeze_feature_idx] 146 | return PSPnet(n_classes, vgg, batch_norm) 147 | def pspnet_vgg13(n_classes, batch_size, pretrained=False, fixed_feature=True): 148 | batch_norm = False if batch_size == 1 else True 149 | vgg = vgg_13(batch_norm, pretrained, fixed_feature) 150 | copy_feature_info = vgg.get_copy_feature_info() 151 | 
squeeze_feature_idx = copy_feature_info[3].index - 1 152 | vgg.features = vgg.features[:squeeze_feature_idx] 153 | return PSPnet(n_classes, vgg, batch_norm) 154 | def pspnet_vgg16(n_classes, batch_size, pretrained=False, fixed_feature=True): 155 | batch_norm = False if batch_size == 1 else True 156 | vgg = vgg_16(batch_norm, pretrained, fixed_feature) 157 | copy_feature_info = vgg.get_copy_feature_info() 158 | squeeze_feature_idx = copy_feature_info[3].index - 1 159 | vgg.features = vgg.features[:squeeze_feature_idx] 160 | return PSPnet(n_classes, vgg, batch_norm) 161 | def pspnet_vgg19(n_classes, batch_size, pretrained=False, fixed_feature=True): 162 | batch_norm = False if batch_size == 1 else True 163 | vgg = vgg_19(batch_norm, pretrained, fixed_feature) 164 | copy_feature_info = vgg.get_copy_feature_info() 165 | squeeze_feature_idx = copy_feature_info[3].index - 1 166 | vgg.features = vgg.features[:squeeze_feature_idx] 167 | return PSPnet(n_classes, vgg, batch_norm) 168 | 169 | def pspnet_resnet18(n_classes, batch_size, pretrained=False, fixed_feature=True): 170 | batch_norm = False if batch_size == 1 else True 171 | resnet = resnet18(pretrained, fixed_feature) 172 | copy_feature_info = resnet.get_copy_feature_info() 173 | squeeze_feature_idx = copy_feature_info[3].index 174 | resnet.features = resnet.features[:squeeze_feature_idx] 175 | return PSPnet(n_classes, resnet, batch_norm) 176 | def pspnet_resnet34(n_classes, batch_size, pretrained=False, fixed_feature=True): 177 | batch_norm = False if batch_size == 1 else True 178 | resnet = resnet34(pretrained, fixed_feature) 179 | copy_feature_info = resnet.get_copy_feature_info() 180 | squeeze_feature_idx = copy_feature_info[3].index 181 | resnet.features = resnet.features[:squeeze_feature_idx] 182 | return PSPnet(n_classes, resnet, batch_norm) 183 | def pspnet_resnet50(n_classes, batch_size, pretrained=False, fixed_feature=True): 184 | batch_norm = False if batch_size == 1 else True 185 | resnet = resnet50(pretrained, fixed_feature) 186 | copy_feature_info = resnet.get_copy_feature_info() 187 | squeeze_feature_idx = copy_feature_info[3].index 188 | resnet.features = resnet.features[:squeeze_feature_idx] 189 | return PSPnet(n_classes, resnet, batch_norm) 190 | def pspnet_resnet101(n_classes, batch_size, pretrained=False, fixed_feature=True): 191 | batch_norm = False if batch_size == 1 else True 192 | resnet = resnet101(pretrained, fixed_feature) 193 | copy_feature_info = resnet.get_copy_feature_info() 194 | squeeze_feature_idx = copy_feature_info[3].index 195 | resnet.features = resnet.features[:squeeze_feature_idx] 196 | return PSPnet(n_classes, resnet, batch_norm) 197 | def pspnet_resnet152(n_classes, batch_size, pretrained=False, fixed_feature=True): 198 | batch_norm = False if batch_size == 1 else True 199 | resnet = resnet152(pretrained, fixed_feature) 200 | copy_feature_info = resnet.get_copy_feature_info() 201 | squeeze_feature_idx = copy_feature_info[3].index 202 | resnet.features = resnet.features[:squeeze_feature_idx] 203 | return PSPnet(n_classes, resnet, batch_norm) 204 | 205 | def pspnet_mobilenet_v2(n_classes, batch_size, pretrained=False, fixed_feature=True): 206 | batch_norm = False if batch_size == 1 else True 207 | mobile_net = mobilenet(pretrained, fixed_feature) 208 | copy_feature_info = mobile_net.get_copy_feature_info() 209 | squeeze_feature_idx = copy_feature_info[3].index 210 | mobile_net.features = mobile_net.features[:squeeze_feature_idx] 211 | return PSPnet(n_classes, mobile_net, batch_norm) 212 | 213 | 
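All of the factory functions above share one signature, so any model in this file can also be built through the model_from_name registry defined in all_models.py. A minimal sketch of doing so; the 'pspnet_resnet50' key, the class count, and the input size are illustrative choices, not values fixed by the code:

import torch
from segmentation.models import all_models

# n_classes=34 matches the Cityscapes example used elsewhere in this repo
model = all_models.model_from_name['pspnet_resnet50'](n_classes=34, batch_size=4)
logits = model(torch.randn(4, 3, 224, 224))  # (4, 34, 224, 224): forward() upsamples back to the input size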
-------------------------------------------------------------------------------- /segmentation/models/unet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Unet class. 3 | 4 | Library: TensorFlow 2.2.0, PyTorch 1.5.1 5 | Author: Ian Yoo 6 | Email: thyoostar@gmail.com 7 | """ 8 | from __future__ import absolute_import, division, print_function 9 | 10 | import torch 11 | from ..encoders.squeeze_extractor import * 12 | 13 | class UnetWithEncoder(torch.nn.Module): 14 | 15 | def __init__(self, n_classes, pretrained_model: SqueezeExtractor, batch_norm=True): 16 | super(UnetWithEncoder, self).__init__() 17 | self.copy_feature_info = pretrained_model.get_copy_feature_info() 18 | self.features = pretrained_model.features 19 | 20 | self.up_layer0 = self._make_up_layer(-1, batch_norm) 21 | self.up_layer1 = self._make_up_layer(-2, batch_norm) 22 | self.up_layer2 = self._make_up_layer(-3, batch_norm) 23 | self.up_layer3 = self._make_up_layer(-4, batch_norm) 24 | 25 | self.up_sampling0 = self._make_up_sampling(-1) 26 | self.up_sampling1 = self._make_up_sampling(-2) 27 | self.up_sampling2 = self._make_up_sampling(-3) 28 | self.up_sampling3 = self._make_up_sampling(-4) 29 | 30 | #find out_channels of the top layer and define classifier 31 | for f in reversed(self.up_layer3): 32 | if isinstance(f, nn.Conv2d): 33 | channels = f.out_channels 34 | break 35 | 36 | uplayer4 = [] 37 | uplayer4 += [nn.Conv2d(channels, channels, kernel_size=3, padding=1)] 38 | if batch_norm: 39 | uplayer4 += [nn.BatchNorm2d(channels)] 40 | uplayer4 += [nn.ReLU(inplace=True)] 41 | self.up_layer4 = nn.Sequential(*uplayer4) 42 | 43 | self.up_sampling4 = nn.ConvTranspose2d(channels, channels, kernel_size=4, 44 | stride=2, bias=False) 45 | self.classifier = nn.Sequential(nn.Conv2d(channels, n_classes, kernel_size=1), nn.ReLU(inplace=True)) 46 | self._initialize_weights() 47 | 48 | def _get_last_out_channels(self, features): 49 | for idx, m in reversed(list(enumerate(features.modules()))): 50 | if isinstance(m, nn.Conv2d): 51 | return m.out_channels 52 | return 0 53 | 54 | 55 | def _make_up_sampling(self, cfi_idx): 56 | if cfi_idx == -1: 57 | in_channels = self._get_last_out_channels(self.features) 58 | else: 59 | in_channels = self.copy_feature_info[cfi_idx + 1].out_channels 60 | 61 | out_channels = self.copy_feature_info[cfi_idx].out_channels 62 | return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, 63 | stride=2, bias=False) 64 | 65 | def _make_up_layer(self, cfi_idx, batch_norm): 66 | idx = self.copy_feature_info[cfi_idx].index 67 | for k in reversed(range(0, idx)): 68 | f = self.features[k] 69 | channels = self._get_last_out_channels(f) 70 | 71 | if channels == 0: 72 | continue 73 | 74 | out_channels = self.copy_feature_info[cfi_idx].out_channels 75 | in_channels = out_channels + channels # for concatenation. 
76 | 77 | layer = [] 78 | layer += [nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)] 79 | if batch_norm: 80 | layer += [nn.BatchNorm2d(out_channels)] 81 | layer += [nn.ReLU(inplace=True)] 82 | 83 | return nn.Sequential(*layer) 84 | 85 | assert False 86 | 87 | def _initialize_weights(self): 88 | for m in self.modules(): 89 | if isinstance(m, nn.Conv2d): 90 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 91 | if m.bias is not None: 92 | nn.init.constant_(m.bias, 0) 93 | elif isinstance(m, nn.BatchNorm2d): 94 | nn.init.constant_(m.weight, 1) 95 | nn.init.constant_(m.bias, 0) 96 | elif isinstance(m, nn.Linear): 97 | nn.init.normal_(m.weight, 0, 0.01) 98 | nn.init.constant_(m.bias, 0) 99 | 100 | def forward(self, x): 101 | copy_out = [] 102 | o = x 103 | cpi = self.copy_feature_info[-4:] 104 | copy_idx = 0 105 | 106 | for i in range(len(self.features)): 107 | o = self.features[i](o) 108 | if i == cpi[copy_idx].index - 1: 109 | copy_out.append(o) 110 | if copy_idx + 1 < len(cpi): 111 | copy_idx += 1 112 | 113 | o = self.up_sampling0(o) 114 | o = o[:, :, 1:1 + copy_out[3].size()[2], 1:1 + copy_out[3].size()[3]] 115 | o = torch.cat([o, copy_out[3]], dim=1) 116 | o = self.up_layer0(o) 117 | 118 | o = self.up_sampling1(o) 119 | o = o[:, :, 1:1 + copy_out[2].size()[2], 1:1 + copy_out[2].size()[3]] 120 | o = torch.cat([o, copy_out[2]], dim=1) 121 | o = self.up_layer1(o) 122 | 123 | o = self.up_sampling2(o) 124 | o = o[:, :, 1:1 + copy_out[1].size()[2], 1:1 + copy_out[1].size()[3]] 125 | o = torch.cat([o, copy_out[1]], dim=1) 126 | o = self.up_layer2(o) 127 | 128 | o = self.up_sampling3(o) 129 | o = o[:, :, 1:1 + copy_out[0].size()[2], 1:1 + copy_out[0].size()[3]] 130 | o = torch.cat([o, copy_out[0]], dim=1) 131 | o = self.up_layer3(o) 132 | 133 | o = self.up_sampling4(o) 134 | cx = int((o.shape[3] - x.shape[3]) / 2) 135 | cy = int((o.shape[2] - x.shape[2]) / 2) 136 | o = o[:, :, cy:cy + x.shape[2], cx:cx + x.shape[3]] 137 | o = self.up_layer4(o) 138 | o = self.classifier(o) 139 | 140 | return o 141 | 142 | cfgs = { 143 | 'A': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 1024, 1024, 'U', 512, 512, 144 | 'U', 256, 256, 'U', 128, 128, 'U', 64, 64] 145 | } 146 | 147 | class Unet(torch.nn.Module): 148 | 149 | def __init__(self, n_classes, cfg, batch_norm=True): 150 | super(Unet, self).__init__() 151 | self.features = self._make_layers(cfg, batch_norm) 152 | self.classifier = nn.Sequential(nn.Conv2d(cfg[-1], n_classes, kernel_size=1), nn.Sigmoid()) 153 | 154 | self._initialize_weights() 155 | 156 | 157 | def _make_layers(self, cfg, batch_norm): 158 | layers = [] 159 | in_channels = 3 160 | for v in cfg: 161 | if v == 'M': 162 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 163 | elif v == 'U': 164 | layers += [nn.ConvTranspose2d(in_channels, int(in_channels / 2), kernel_size=4, 165 | stride=2, bias=False)] 166 | else: 167 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 168 | if batch_norm: 169 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 170 | else: 171 | layers += [conv2d, nn.ReLU(inplace=True)] 172 | in_channels = v 173 | return nn.Sequential(*layers) 174 | 175 | 176 | 177 | def _initialize_weights(self): 178 | for m in self.modules(): 179 | if isinstance(m, nn.Conv2d): 180 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 181 | if m.bias is not None: 182 | nn.init.constant_(m.bias, 0) 183 | elif isinstance(m, nn.BatchNorm2d): 184 | nn.init.constant_(m.weight, 1) 185 | 
nn.init.constant_(m.bias, 0) 186 | 187 | def forward(self, x): 188 | o = x 189 | size_features = len(self.features) 190 | copys = [] 191 | for i in range(size_features): 192 | o = self.features[i](o) 193 | if isinstance(self.features[i], nn.ConvTranspose2d): 194 | copy = copys.pop() 195 | o = o[:, :, 1:1 + copy.size()[2], 1:1 + copy.size()[3]] 196 | o = torch.cat([o, copy], dim=1) 197 | if i + 1 >= size_features: 198 | continue 199 | if isinstance(self.features[i+1], nn.MaxPool2d): 200 | copys += [o] 201 | 202 | cx = int((o.shape[3] - x.shape[3]) / 2) 203 | cy = int((o.shape[2] - x.shape[2]) / 2) 204 | o = o[:, :, cy:cy + x.shape[2], cx:cx + x.shape[3]] 205 | o = self.classifier(o) 206 | 207 | return o 208 | 209 | from ..encoders.vgg import * 210 | from ..encoders.resnet import * 211 | from ..encoders.mobilenet import * 212 | 213 | def unet_vgg11(n_classes, batch_size, pretrained=False, fixed_feature=True): 214 | batch_norm = False if batch_size == 1 else True 215 | vgg = vgg_11(batch_norm, pretrained, fixed_feature) 216 | return UnetWithEncoder(n_classes, vgg, batch_norm) 217 | def unet_vgg13(n_classes, batch_size, pretrained=False, fixed_feature=True): 218 | batch_norm = False if batch_size == 1 else True 219 | vgg = vgg_13(batch_norm, pretrained, fixed_feature) 220 | return UnetWithEncoder(n_classes, vgg, batch_norm) 221 | def unet_vgg16(n_classes, batch_size, pretrained=False, fixed_feature=True): 222 | batch_norm = False if batch_size == 1 else True 223 | vgg = vgg_16(batch_norm, pretrained, fixed_feature) 224 | return UnetWithEncoder(n_classes, vgg, batch_norm) 225 | def unet_vgg19(n_classes, batch_size, pretrained=False, fixed_feature=True): 226 | batch_norm = False if batch_size == 1 else True 227 | vgg = vgg_19(batch_norm, pretrained, fixed_feature) 228 | return UnetWithEncoder(n_classes, vgg, batch_norm) 229 | 230 | def unet_resnet18(n_classes, batch_size, pretrained=False, fixed_feature=True): 231 | batch_norm = False if batch_size == 1 else True 232 | resnet = resnet18(pretrained, fixed_feature) 233 | return UnetWithEncoder(n_classes, resnet, batch_norm) 234 | def unet_resnet34(n_classes, batch_size, pretrained=False, fixed_feature=True): 235 | batch_norm = False if batch_size == 1 else True 236 | resnet = resnet34(pretrained, fixed_feature) 237 | return UnetWithEncoder(n_classes, resnet, batch_norm) 238 | def unet_resnet50(n_classes, batch_size, pretrained=False, fixed_feature=True): 239 | batch_norm = False if batch_size == 1 else True 240 | resnet = resnet50(pretrained, fixed_feature) 241 | return UnetWithEncoder(n_classes, resnet, batch_norm) 242 | def unet_resnet101(n_classes, batch_size, pretrained=False, fixed_feature=True): 243 | batch_norm = False if batch_size == 1 else True 244 | resnet = resnet101(pretrained, fixed_feature) 245 | return UnetWithEncoder(n_classes, resnet, batch_norm) 246 | def unet_resnet152(n_classes, batch_size, pretrained=False, fixed_feature=True): 247 | batch_norm = False if batch_size == 1 else True 248 | resnet = resnet152(pretrained, fixed_feature) 249 | return UnetWithEncoder(n_classes, resnet, batch_norm) 250 | 251 | def unet_mobilenet_v2(n_classes, batch_size, pretrained=False, fixed_feature=True): 252 | batch_norm = False if batch_size == 1 else True 253 | mobile_net = mobilenet(pretrained, fixed_feature) 254 | return UnetWithEncoder(n_classes, mobile_net, batch_norm) 255 | 256 | def unet(n_classes, batch_size): 257 | batch_norm = False if batch_size == 1 else True 258 | return Unet(n_classes, cfgs['A'], batch_norm) 259 | 260 | 
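The encoder-free Unet above is driven entirely by cfgs['A']: channel widths double through the 'M' (max-pool) stages down to 1024, then halve back through the 'U' (transposed-convolution) stages, with each upsampled map cropped and concatenated with the feature map saved just before the matching pool. A short hedged sketch of using it directly; the batch and image sizes are arbitrary examples, not requirements:

import torch
from segmentation.models.unet import unet

model = unet(n_classes=34, batch_size=4)  # batch_size > 1 switches BatchNorm on in the factory
out = model(torch.randn(4, 3, 224, 224))  # (4, 34, 224, 224) after the final center crop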
-------------------------------------------------------------------------------- /segmentation/predict.py: -------------------------------------------------------------------------------- 1 | """ 2 | The predict functions. 3 | The main function writes a color image converted from a gray labeled image. 4 | 5 | Library: TensorFlow 2.2.0, PyTorch 1.5.1, OpenCV-Python 4.1.1.26 6 | Author: Ian Yoo 7 | Email: thyoostar@gmail.com 8 | """ 9 | from __future__ import absolute_import, division, print_function 10 | 11 | import random 12 | import cv2 13 | import torch 14 | import numpy as np 15 | import os 16 | import pathlib 17 | import six 18 | 19 | def parent(path): 20 | path = pathlib.Path(path) 21 | return str(path.parent) 22 | 23 | def exist(path): 24 | return os.path.exists(str(path)) 25 | 26 | def mkdir(path): 27 | pathlib.Path(path).mkdir(parents=True, exist_ok=True) 28 | 29 | random.seed(0) 30 | class_colors = [(random.randint(0, 255), random.randint( 31 | 0, 255), random.randint(0, 255)) for _ in range(5000)] 32 | 33 | def convert_seg_gray_to_color(input, n_classes, output_path=None, colors=class_colors): 34 | """ 35 | Convert a gray segmented image to color. 36 | 37 | :param input: accepts two types (ndarray or string); a string is treated as a file path. 38 | :param n_classes: number of classes. 39 | :param output_path: output path. If it is None, this function returns the result array (ndarray). 40 | :param colors: refer to the 'class_colors' format. Default: randomly assigned colors. 41 | :return: if output_path is None, the result array (ndarray). 42 | """ 43 | if isinstance(input, six.string_types): 44 | seg = cv2.imread(input, flags=cv2.IMREAD_GRAYSCALE) 45 | elif type(input) is np.ndarray: 46 | assert len(input.shape) == 2, "Input should be h,w " 47 | seg = input 48 | 49 | height = seg.shape[0] 50 | width = seg.shape[1] 51 | 52 | seg_img = np.zeros((height, width, 3)) 53 | 54 | for c in range(n_classes): 55 | seg_arr = seg[:, :] == c 56 | seg_img[:, :, 0] += ((seg_arr) * colors[c][0]).astype('uint8') 57 | seg_img[:, :, 1] += ((seg_arr) * colors[c][1]).astype('uint8') 58 | seg_img[:, :, 2] += ((seg_arr) * colors[c][2]).astype('uint8') 59 | 60 | if output_path: 61 | cv2.imwrite(output_path, seg_img) 62 | else: 63 | return seg_img 64 | 65 | def predict(model, input_path, output_path, colors=class_colors): 66 | """ 67 | Save the colorized prediction of a trained model for the given image. 68 | 69 | :param model: a network model. 70 | :param input_path: the input file path. 71 | :param output_path: the output file path. 72 | :param colors: refer to the 'class_colors' format. Default: randomly assigned colors. 73 | :return: model result. 
74 | """ 75 | model.eval() 76 | 77 | img = cv2.imread(input_path, flags=cv2.IMREAD_COLOR) 78 | ori_height = img.shape[0] 79 | ori_width = img.shape[1] 80 | 81 | model_width = model.img_width 82 | model_height = model.img_height 83 | 84 | if model_width != ori_width or model_height != ori_height: 85 | img = cv2.resize(img, (model_width, model_height), interpolation=cv2.INTER_NEAREST) 86 | 87 | 88 | data = img.transpose((2, 0, 1)) 89 | data = data[None, :, :, :] 90 | data = torch.from_numpy(data).float() 91 | 92 | if next(model.parameters()).is_cuda: 93 | if not torch.cuda.is_available(): 94 | raise ValueError("A model was trained via .cuda(), but this system can not support cuda.") 95 | data = data.cuda() 96 | 97 | score = model(data) 98 | 99 | lbl_pred = score.data.max(1)[1].cpu().numpy()[:, :, :] 100 | lbl_pred = lbl_pred.transpose((1, 2, 0)) 101 | n_classes = np.max(lbl_pred) 102 | lbl_pred = lbl_pred.reshape(model_height, model_width) 103 | 104 | seg_img = convert_seg_gray_to_color(lbl_pred, n_classes, colors=colors) 105 | 106 | if model_width != ori_width or model_height != ori_height: 107 | seg_img = cv2.resize(seg_img, (ori_width, ori_height), interpolation=cv2.INTER_NEAREST) 108 | 109 | if not exist(parent(output_path)): 110 | mkdir(parent(output_path)) 111 | 112 | cv2.imwrite(output_path, seg_img) 113 | 114 | return score -------------------------------------------------------------------------------- /segmentation/test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/test/__init__.py -------------------------------------------------------------------------------- /segmentation/test/dataset/cityspaces/images/test/lindau_000000_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/test/dataset/cityspaces/images/test/lindau_000000_000019.png -------------------------------------------------------------------------------- /segmentation/test/dataset/cityspaces/images/test/lindau_000001_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/test/dataset/cityspaces/images/test/lindau_000001_000019.png -------------------------------------------------------------------------------- /segmentation/test/dataset/cityspaces/images/test/lindau_000002_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/test/dataset/cityspaces/images/test/lindau_000002_000019.png -------------------------------------------------------------------------------- /segmentation/test/dataset/cityspaces/images/test/lindau_000003_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/test/dataset/cityspaces/images/test/lindau_000003_000019.png -------------------------------------------------------------------------------- 
/segmentation/test/dataset/cityspaces/images/test/lindau_000004_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/test/dataset/cityspaces/images/test/lindau_000004_000019.png -------------------------------------------------------------------------------- /segmentation/test/dataset/cityspaces/images/test/lindau_000005_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/test/dataset/cityspaces/images/test/lindau_000005_000019.png -------------------------------------------------------------------------------- /segmentation/test/dataset/cityspaces/images/train/jena_000000_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/test/dataset/cityspaces/images/train/jena_000000_000019.png -------------------------------------------------------------------------------- /segmentation/test/dataset/cityspaces/images/train/jena_000001_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/test/dataset/cityspaces/images/train/jena_000001_000019.png -------------------------------------------------------------------------------- /segmentation/test/dataset/cityspaces/images/train/jena_000002_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/test/dataset/cityspaces/images/train/jena_000002_000019.png -------------------------------------------------------------------------------- /segmentation/test/dataset/cityspaces/images/train/jena_000003_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/test/dataset/cityspaces/images/train/jena_000003_000019.png -------------------------------------------------------------------------------- /segmentation/test/dataset/cityspaces/images/train/jena_000004_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/test/dataset/cityspaces/images/train/jena_000004_000019.png -------------------------------------------------------------------------------- /segmentation/test/dataset/cityspaces/images/train/jena_000005_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/test/dataset/cityspaces/images/train/jena_000005_000019.png -------------------------------------------------------------------------------- /segmentation/test/dataset/cityspaces/images/train/jena_000006_000019.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/test/dataset/cityspaces/images/train/jena_000006_000019.png -------------------------------------------------------------------------------- /segmentation/test/dataset/cityspaces/images/train/jena_000007_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/test/dataset/cityspaces/images/train/jena_000007_000019.png -------------------------------------------------------------------------------- /segmentation/test/dataset/cityspaces/images/train/jena_000008_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/test/dataset/cityspaces/images/train/jena_000008_000019.png -------------------------------------------------------------------------------- /segmentation/test/dataset/cityspaces/images/train/jena_000009_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/test/dataset/cityspaces/images/train/jena_000009_000019.png -------------------------------------------------------------------------------- /segmentation/test/dataset/cityspaces/images/train/jena_000010_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/test/dataset/cityspaces/images/train/jena_000010_000019.png -------------------------------------------------------------------------------- /segmentation/test/dataset/cityspaces/input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/test/dataset/cityspaces/input.png -------------------------------------------------------------------------------- /segmentation/test/dataset/cityspaces/labeled/test/lindau_000000_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/test/dataset/cityspaces/labeled/test/lindau_000000_000019.png -------------------------------------------------------------------------------- /segmentation/test/dataset/cityspaces/labeled/test/lindau_000001_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/test/dataset/cityspaces/labeled/test/lindau_000001_000019.png -------------------------------------------------------------------------------- /segmentation/test/dataset/cityspaces/labeled/test/lindau_000002_000019.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/test/dataset/cityspaces/labeled/test/lindau_000002_000019.png -------------------------------------------------------------------------------- /segmentation/test/dataset/cityspaces/labeled/test/lindau_000003_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/test/dataset/cityspaces/labeled/test/lindau_000003_000019.png -------------------------------------------------------------------------------- /segmentation/test/dataset/cityspaces/labeled/test/lindau_000004_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/test/dataset/cityspaces/labeled/test/lindau_000004_000019.png -------------------------------------------------------------------------------- /segmentation/test/dataset/cityspaces/labeled/test/lindau_000005_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/test/dataset/cityspaces/labeled/test/lindau_000005_000019.png -------------------------------------------------------------------------------- /segmentation/test/dataset/cityspaces/labeled/train/jena_000000_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/test/dataset/cityspaces/labeled/train/jena_000000_000019.png -------------------------------------------------------------------------------- /segmentation/test/dataset/cityspaces/labeled/train/jena_000001_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/test/dataset/cityspaces/labeled/train/jena_000001_000019.png -------------------------------------------------------------------------------- /segmentation/test/dataset/cityspaces/labeled/train/jena_000002_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/test/dataset/cityspaces/labeled/train/jena_000002_000019.png -------------------------------------------------------------------------------- /segmentation/test/dataset/cityspaces/labeled/train/jena_000003_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/test/dataset/cityspaces/labeled/train/jena_000003_000019.png -------------------------------------------------------------------------------- /segmentation/test/dataset/cityspaces/labeled/train/jena_000004_000019.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/test/dataset/cityspaces/labeled/train/jena_000004_000019.png -------------------------------------------------------------------------------- /segmentation/test/dataset/cityspaces/labeled/train/jena_000005_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/test/dataset/cityspaces/labeled/train/jena_000005_000019.png -------------------------------------------------------------------------------- /segmentation/test/dataset/cityspaces/labeled/train/jena_000006_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/test/dataset/cityspaces/labeled/train/jena_000006_000019.png -------------------------------------------------------------------------------- /segmentation/test/dataset/cityspaces/labeled/train/jena_000007_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/test/dataset/cityspaces/labeled/train/jena_000007_000019.png -------------------------------------------------------------------------------- /segmentation/test/dataset/cityspaces/labeled/train/jena_000008_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/test/dataset/cityspaces/labeled/train/jena_000008_000019.png -------------------------------------------------------------------------------- /segmentation/test/dataset/cityspaces/labeled/train/jena_000009_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/test/dataset/cityspaces/labeled/train/jena_000009_000019.png -------------------------------------------------------------------------------- /segmentation/test/dataset/cityspaces/labeled/train/jena_000010_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/segmentation/test/dataset/cityspaces/labeled/train/jena_000010_000019.png -------------------------------------------------------------------------------- /segmentation/test/scrach.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | 3 | from segmentation.data_loader.segmentation_dataset import SegmentationDataset 4 | from segmentation.data_loader.transform import Rescale, ToTensor 5 | from segmentation.trainer import Trainer 6 | from segmentation.predict import * 7 | from segmentation.models import all_models 8 | from util.logger import Logger 9 | 10 | train_images = r'dataset/cityspaces/images/train' 11 | test_images = r'dataset/cityspaces/images/test' 12 | train_labled = r'dataset/cityspaces/labeled/train' 13 | test_labeled = r'dataset/cityspaces/labeled/test' 
14 | 15 | if __name__ == '__main__': 16 | model_name = "fcn8_vgg16" 17 | device = 'cuda' 18 | batch_size = 4 19 | n_classes = 34 20 | num_epochs = 10 21 | image_axis_minimum_size = 200 22 | pretrained = True 23 | fixed_feature = False 24 | 25 | logger = Logger(model_name=model_name, data_name='example') 26 | 27 | ### Loader 28 | compose = transforms.Compose([ 29 | Rescale(image_axis_minimum_size), 30 | ToTensor() 31 | ]) 32 | 33 | train_datasets = SegmentationDataset(train_images, train_labled, n_classes, compose) 34 | train_loader = torch.utils.data.DataLoader(train_datasets, batch_size=batch_size, shuffle=True, drop_last=True) 35 | 36 | test_datasets = SegmentationDataset(test_images, test_labeled, n_classes, compose) 37 | test_loader = torch.utils.data.DataLoader(test_datasets, batch_size=batch_size, shuffle=True, drop_last=True) 38 | 39 | ### Model 40 | model = all_models.model_from_name[model_name](n_classes, batch_size, 41 | pretrained=pretrained, 42 | fixed_feature=fixed_feature) 43 | model.to(device) 44 | 45 | ### Load model 46 | ### please check the folder: (.segmentation/test/runs/models) 47 | #logger.load_model(model, 'epoch_15') 48 | 49 | ### Optimizers 50 | if pretrained and fixed_feature: # feature extraction: update only the unfrozen parameters 51 | params_to_update = model.parameters() 52 | print("Params to learn:") 53 | params_to_update = [] 54 | for name, param in model.named_parameters(): 55 | if param.requires_grad == True: 56 | params_to_update.append(param) 57 | print("\t", name) 58 | optimizer = torch.optim.Adadelta(params_to_update) 59 | else: 60 | optimizer = torch.optim.Adadelta(model.parameters()) 61 | 62 | ### Train 63 | #scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1) 64 | trainer = Trainer(model, optimizer, logger, num_epochs, train_loader, test_loader) 65 | trainer.train() 66 | 67 | 68 | #### Write the prediction result. 69 | predict(model, r'dataset/cityspaces/input.png', 70 | r'dataset/cityspaces/output.png') 71 | 72 | 73 | -------------------------------------------------------------------------------- /segmentation/trainer.py: -------------------------------------------------------------------------------- 1 | """ 2 | The trainer class. 3 | 4 | Library: TensorFlow 2.2.0, PyTorch 1.5.1 5 | Author: Ian Yoo 6 | Email: thyoostar@gmail.com 7 | """ 8 | from __future__ import absolute_import, division, print_function 9 | 10 | from util.validation import * 11 | from util.logger import * 12 | 13 | try: 14 | from tqdm import tqdm 15 | from tqdm import trange 16 | except ImportError: 17 | print("tqdm and trange not found, disabling progress bars") 18 | 19 | def tqdm(iter): 20 | return iter 21 | 22 | def trange(iter): 23 | return iter 24 | 25 | TQDM_COLS = 80 26 | 27 | def cross_entropy2d(input, target): 28 | # input: (n, c, h, w), target: (n, h, w) 29 | n, c, h, w = input.size() 30 | 31 | # input: (n*h*w, c) 32 | input = input.transpose(1, 2).transpose(2, 3).contiguous() 33 | input = input[target.view(n, h, w, 1).repeat(1, 1, 1, c) >= 0] 34 | input = input.view(-1, c) 35 | 36 | # target: (n*h*w,) 37 | mask = target >= 0.0 38 | target = target[mask] 39 | 40 | func_loss = torch.nn.CrossEntropyLoss() 41 | loss = func_loss(input, target) 42 | 43 | return loss 44 | 45 | 46 | class Trainer(object): 47 | 48 | def __init__(self, model, optimizer, logger, num_epochs, train_loader, 49 | test_loader=None, 50 | epoch=0, 51 | log_batch_stride=30, 52 | check_point_epoch_stride=60, 53 | scheduler=None): 54 | """ 55 | :param model: A network model to train. 56 | :param optimizer: An optimizer. 
57 | :param logger: The logger for writing results to Tensorboard. 58 | :param num_epochs: the total number of epochs to train. 59 | :param train_loader: a PyTorch DataLoader. 60 | :param test_loader: a PyTorch DataLoader. 61 | :param epoch: the start epoch number. 62 | :param log_batch_stride: the batch interval at which logs are written. 63 | :param check_point_epoch_stride: the epoch interval at which a model checkpoint is saved. 64 | :param scheduler: an optimizer scheduler for adjusting the learning rate. 65 | """ 66 | self.cuda = torch.cuda.is_available() 67 | self.model = model 68 | self.optim = optimizer 69 | self.logger = logger 70 | self.train_loader = train_loader 71 | self.test_loader = test_loader 72 | self.num_epochs = num_epochs 73 | self.check_point_step = check_point_epoch_stride 74 | self.log_batch_stride = log_batch_stride 75 | self.scheduler = scheduler 76 | 77 | self.epoch = epoch 78 | 79 | def train(self): 80 | if not next(self.model.parameters()).is_cuda and self.cuda: 81 | raise ValueError("The model should be moved to CUDA via .cuda() before constructing the optimizer.") 82 | 83 | for epoch in trange(self.epoch, self.num_epochs, 84 | position=0, 85 | desc='Train', ncols=TQDM_COLS): 86 | self.epoch = epoch 87 | 88 | # train 89 | self._train_epoch() 90 | 91 | # step forward to reduce the learning rate in the optimizer. 92 | if self.scheduler: 93 | self.scheduler.step() 94 | 95 | # model checkpoints 96 | if epoch%self.check_point_step == 0: 97 | self.logger.save_model_and_optimizer(self.model, 98 | self.optim, 99 | 'epoch_{}'.format(epoch)) 100 | 101 | 102 | 103 | def evaluate(self): 104 | num_batches = len(self.test_loader) 105 | self.model.eval() 106 | 107 | with torch.no_grad(): 108 | for n_batch, (sample_batched) in tqdm(enumerate(self.test_loader), 109 | total=num_batches, 110 | leave=False, 111 | desc="Valid epoch={}".format(self.epoch), 112 | ncols=TQDM_COLS): 113 | self._eval_batch(sample_batched, n_batch, num_batches) 114 | 115 | def _train_epoch(self): 116 | 117 | num_batches = len(self.train_loader) 118 | 119 | if self.test_loader: 120 | dataloader_iterator = iter(self.test_loader) 121 | 122 | for n_batch, (sample_batched) in tqdm(enumerate(self.train_loader), 123 | total=num_batches, 124 | leave=False, 125 | desc="Train epoch={}".format(self.epoch), 126 | ncols=TQDM_COLS): 127 | self.model.train() 128 | data = sample_batched['image'] 129 | target = sample_batched['annotation'] 130 | 131 | if self.cuda: 132 | data, target = data.cuda(), target.cuda() 133 | 134 | self.optim.zero_grad() 135 | 136 | torch.cuda.empty_cache() 137 | 138 | score = self.model(data) 139 | loss = cross_entropy2d(score, target) 140 | 141 | loss_data = loss.data.item() 142 | if np.isnan(loss_data): 143 | raise ValueError('loss is nan while training') 144 | 145 | loss.backward() 146 | self.optim.step() 147 | 148 | if n_batch%self.log_batch_stride != 0: 149 | continue 150 | 151 | self.logger.store_checkpoint_var('img_width', data.shape[3]) 152 | self.logger.store_checkpoint_var('img_height', data.shape[2]) 153 | 154 | self.model.img_width = data.shape[3] 155 | self.model.img_height = data.shape[2] 156 | 157 | #write logs to Tensorboard. 
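# score holds (N, C, H, W) logits; argmax over the class axis gives per-pixel label predictions,
# and label_accuracy_score (from util/validation.py) compares them with the ground truth to compute
# overall pixel accuracy, per-class accuracy, mean IoU, and frequency-weighted accuracy.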
    def evaluate(self):
        num_batches = len(self.test_loader)
        self.model.eval()

        with torch.no_grad():
            for n_batch, sample_batched in tqdm(enumerate(self.test_loader),
                                                total=num_batches,
                                                leave=False,
                                                desc="Valid epoch={}".format(self.epoch),
                                                ncols=TQDM_COLS):
                self._eval_batch(sample_batched, n_batch, num_batches)

    def _train_epoch(self):

        num_batches = len(self.train_loader)

        if self.test_loader:
            dataloader_iterator = iter(self.test_loader)

        for n_batch, sample_batched in tqdm(enumerate(self.train_loader),
                                            total=num_batches,
                                            leave=False,
                                            desc="Train epoch={}".format(self.epoch),
                                            ncols=TQDM_COLS):
            self.model.train()
            data = sample_batched['image']
            target = sample_batched['annotation']

            if self.cuda:
                data, target = data.cuda(), target.cuda()

            self.optim.zero_grad()

            torch.cuda.empty_cache()

            score = self.model(data)
            loss = cross_entropy2d(score, target)

            loss_data = loss.data.item()
            if np.isnan(loss_data):
                raise ValueError('loss is nan while training')

            loss.backward()
            self.optim.step()

            if n_batch % self.log_batch_stride != 0:
                continue

            self.logger.store_checkpoint_var('img_width', data.shape[3])
            self.logger.store_checkpoint_var('img_height', data.shape[2])

            self.model.img_width = data.shape[3]
            self.model.img_height = data.shape[2]

            # write logs to TensorBoard.
            lbl_pred = score.data.max(1)[1].cpu().numpy()
            lbl_true = target.data.cpu().numpy()
            acc, acc_cls, mean_iou, fwavacc = \
                label_accuracy_score(lbl_true, lbl_pred, n_class=score.shape[1])

            self.logger.log_train(loss, 'loss', self.epoch, n_batch, num_batches)
            self.logger.log_train(acc, 'acc', self.epoch, n_batch, num_batches)
            self.logger.log_train(acc_cls, 'acc_cls', self.epoch, n_batch, num_batches)
            self.logger.log_train(mean_iou, 'mean_iou', self.epoch, n_batch, num_batches)
            self.logger.log_train(fwavacc, 'fwavacc', self.epoch, n_batch, num_batches)

            # write result images at the start of each epoch.
            if n_batch == 0:
                log_img = self.logger.concatenate_images([lbl_pred, lbl_true], input_axis='byx')
                log_img = self.logger.concatenate_images([log_img, data.cpu().numpy()])
                self.logger.log_images_train(log_img, self.epoch, n_batch, num_batches,
                                             nrows=data.shape[0])

            # if the trainer has a test loader, evaluate the model on one test batch.
            if self.test_loader:
                self.model.eval()
                with torch.no_grad():
                    try:
                        sample_batched = next(dataloader_iterator)
                    except StopIteration:
                        dataloader_iterator = iter(self.test_loader)
                        sample_batched = next(dataloader_iterator)

                    self._eval_batch(sample_batched, n_batch, num_batches)


    def _eval_batch(self, sample_batched, n_batch, num_batches):
        data = sample_batched['image']
        target = sample_batched['annotation']

        if self.cuda:
            data, target = data.cuda(), target.cuda()
            torch.cuda.empty_cache()

        score = self.model(data)

        loss = cross_entropy2d(score, target)
        loss_data = loss.data.item()
        if np.isnan(loss_data):
            raise ValueError('loss is nan while validating')

        lbl_pred = score.data.max(1)[1].cpu().numpy()
        lbl_true = target.data.cpu().numpy()
        acc, acc_cls, mean_iou, fwavacc = \
            label_accuracy_score(lbl_true, lbl_pred, n_class=score.shape[1])

        self.logger.log_test(loss, 'loss', self.epoch, n_batch, num_batches)
        self.logger.log_test(acc, 'acc', self.epoch, n_batch, num_batches)
        self.logger.log_test(acc_cls, 'acc_cls', self.epoch, n_batch, num_batches)
        self.logger.log_test(mean_iou, 'mean_iou', self.epoch, n_batch, num_batches)
        self.logger.log_test(fwavacc, 'fwavacc', self.epoch, n_batch, num_batches)

        if n_batch == 0:
            log_img = self.logger.concatenate_images([lbl_pred, lbl_true], input_axis='byx')
            log_img = self.logger.concatenate_images([log_img, data.cpu().numpy()])
            self.logger.log_images_test(log_img, self.epoch, n_batch, num_batches,
                                        nrows=data.shape[0])

    def _write_img(self, score, target, input_img, n_batch, num_batches):
        # Note: unused helper; Logger exposes log_images_train/log_images_test,
        # so the test writer is used here.
        lbl_pred = score.data.max(1)[1].cpu().numpy()
        lbl_true = target.data.cpu().numpy()

        log_img = self.logger.concatenate_images([lbl_pred, lbl_true], input_axis='byx')
        log_img = self.logger.concatenate_images([log_img, input_img.cpu().numpy()])
        self.logger.log_images_test(log_img, self.epoch, n_batch, num_batches, nrows=log_img.shape[0])
--------------------------------------------------------------------------------
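The commented-out StepLR line in the example script above hints at the intended use of the Trainer's `scheduler` argument; a minimal sketch (the StepLR values are illustrative, and any standard `torch.optim.lr_scheduler` works):

scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
trainer = Trainer(model, optimizer, logger, num_epochs, train_loader,
                  test_loader, scheduler=scheduler)
trainer.train()  # scheduler.step() is called once per epoch inside train()
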
/setup.cfg:
--------------------------------------------------------------------------------
[metadata]
description-file = README.md
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from __future__ import print_function

from setuptools import find_packages, setup


def get_long_description():
    with open('README.md') as f:
        long_description = f.read()

    try:
        import github2pypi

        return github2pypi.replace_url(
            slug='IanTaehoonYoo/semantic-segmentation-pytorch', content=long_description
        )
    except Exception:
        return long_description

setup(name="seg_torch",
      version="0.1.7",
      description="Semantic Segmentation with Pytorch",
      long_description=get_long_description(),
      long_description_content_type='text/markdown',
      author="Ian Yoo",
      author_email='thyoostar@gmail.com',
      platforms=["any"],
      license="MIT",
      url="https://github.com/IanTaehoonYoo/semantic-segmentation-pytorch",
      packages=find_packages(exclude=["segmentation.test.dataset"]),
      install_requires=[
          "torch>=1.5.0",
          "torchvision>=0.5.0",
          "tensorboardX>=2.0",
          "opencv-python",
          "tqdm"],
      extras_require={
          "tensorflow": ["tensorflow"],  # optional: provides backbone models.
      },
      classifiers=['License :: OSI Approved :: MIT License']
      )
--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IanTaehoonYoo/semantic-segmentation-pytorch/bebc08b9d2658d30b7fbe63afa27966f88490967/util/__init__.py
--------------------------------------------------------------------------------
/util/checkpoint.py:
--------------------------------------------------------------------------------
"""
CheckpointHandler.

Library: Tensorflow 2.2.0, PyTorch 1.5.1
Author: Ian Yoo
Email: thyoostar@gmail.com
"""

import torch

class CheckpointHandler:

    def store_var(self, var_name, value, exist_fail=False):
        if exist_fail and hasattr(self, var_name):
            raise Exception("var_name='{}' already exists".format(var_name))
        else:
            setattr(self, var_name, value)

    def get_var(self, var_name):
        if hasattr(self, var_name):
            value = getattr(self, var_name)
            return value
        else:
            return False

    def save_checkpoint(self, checkpoint_path, model, optimizer=None):
        if isinstance(model, torch.nn.DataParallel):
            # convert a DataParallel model so it can be loaded later without DataParallel
            self.model_state_dict = model.module.state_dict()
        else:
            self.model_state_dict = model.state_dict()

        if optimizer:
            self.optimizer_state_dict = optimizer.state_dict()

        torch.save(self, checkpoint_path)

    @staticmethod
    def load_checkpoint(checkpoint_path, map_location='cpu'):
        checkpoint = torch.load(checkpoint_path, map_location=map_location)
        return checkpoint
--------------------------------------------------------------------------------
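A minimal round trip with CheckpointHandler, sketched under the assumption that `model` and `optimizer` already exist (the path is illustrative):

handler = CheckpointHandler()
handler.store_var('img_width', 320)  # arbitrary metadata travels with the checkpoint
handler.save_checkpoint('runs/models/example_checkpoint', model, optimizer)

checkpoint = CheckpointHandler.load_checkpoint('runs/models/example_checkpoint')
model.load_state_dict(checkpoint.model_state_dict)
optimizer.load_state_dict(checkpoint.optimizer_state_dict)

Because torch.save pickles the whole handler, anything stored with store_var comes back as an attribute of the loaded object.
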
/util/imshow.py:
--------------------------------------------------------------------------------
"""
imshow function. This is useful for viewing a tiled image.

Library: Tensorflow 2.2.0, PyTorch 1.5.1, OpenCV-Python 4.1.1.26, PIL
Author: Ian Yoo
Email: thyoostar@gmail.com
"""
from __future__ import absolute_import, division, print_function

import numpy as np
import torch
import os
if os.system('python -c "import matplotlib.pyplot as plt;plt.figure()"') != 0:  # no GUI support
    print("non-GUI system, using the Agg backend instead")
    import matplotlib  # https://stackoverflow.com/questions/43003758/
    matplotlib.use("Agg")
import matplotlib.pyplot as plt
from PIL import Image
import pathlib

def parent(path):
    path = pathlib.Path(path)
    return str(path.parent)

def exist(path):
    return os.path.exists(str(path))

def mkdir(path):
    pathlib.Path(path).mkdir(parents=True, exist_ok=True)

def normalize_img(img, vmin=None, vmax=None):
    """
    :param img: Tensor or np.ndarray
    :param vmin: minimum value; taken from the data if None.
    :param vmax: maximum value; taken from the data if None.
    :return: float32 Tensor or np.ndarray of the same shape, scaled to [0, 1]
    """
    if isinstance(img, np.ndarray):
        img = img.astype(np.float32)
        if vmin is None:
            vmin = img.min()
        if vmax is None:
            vmax = img.max()
        img = np.clip(img, vmin, vmax)
        img = (img - vmin) / (vmax - vmin)
        img = np.clip(img, 0, 1)  # guard against numeric error
        return img
    elif isinstance(img, torch.Tensor):
        img = img.type(torch.float32)
        if vmin is None:
            vmin = img.min()
        if vmax is None:
            vmax = img.max()
        img = torch.clamp(img, vmin, vmax)
        img = (img - vmin) / (vmax - vmin)
        img = torch.clamp(img, 0, 1)
        return img
    else:
        raise ValueError

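# normalize_img maps values into [0, 1] for either backend (illustrative values):
#   x = torch.tensor([[-1.0, 0.0], [1.0, 3.0]])
#   normalize_img(x)            # vmin/vmax taken from the data -> [[0, .25], [.5, 1]]
#   normalize_img(x, 0.0, 1.0)  # clipped to the given range first -> [[0, 0], [1, 1]]
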
def imshow(*args, nx=None, vmin=None, vmax=None, path=None, is_color=False, normalize_uint8=False, title=None):
    """
    :param args: accepts the following formats
        Tensor:
            float32, bcyx | cyx
        np:
            uint8, byxc | yxc
            float32, byxc | yxc
        PIL image
    :param nx: number of image columns.
        (default) if the image count is greater than ten, nx is fixed at ten.
    :param vmin: (default) minimum value of the image.
    :param vmax: (default) maximum value of the image.
    :param path: save path. If None, the image is not saved.
    :param is_color: gray or color.
    :param normalize_uint8: if the input is [np.uint8], decides whether to divide by 255.
    :param title: figure name
    """
    # gather all args into one flat list.
    imgs = []
    for arg in args:
        if isinstance(arg, list):
            imgs += arg
        else:
            imgs.append(arg)

    assert len(imgs) >= 1

    for i in range(len(imgs)):
        if isinstance(imgs[i], np.ndarray):
            # if normalize_uint8 is true, uint8 values are normalized to [0, 1]
            if normalize_uint8 and imgs[i].dtype == np.uint8:
                imgs[i] = np.clip(imgs[i].astype(np.float32) / 255, 0, 1)
            else:
                imgs[i] = imgs[i].astype(np.float32)
            if len(imgs[i].shape) == 2:    # np, yx --> 1yx1
                imgs[i] = imgs[i][None, :, :, None]
            elif len(imgs[i].shape) == 3:  # np, yxc --> 1yxc
                imgs[i] = imgs[i][None, :, :, :]
            elif len(imgs[i].shape) == 4:
                pass
            else:
                raise ValueError
        elif isinstance(imgs[i], torch.Tensor):
            imgs[i] = imgs[i].cpu().detach().numpy()
            imgs[i] = imgs[i].astype(np.float32)
            if len(imgs[i].shape) == 2:    # Tensor, yx --> 1yx1
                imgs[i] = imgs[i][None, :, :, None]
            elif len(imgs[i].shape) == 3:  # Tensor, cyx --> 1yxc
                imgs[i] = imgs[i][None, :, :, :]
                imgs[i] = np.transpose(imgs[i], [0, 2, 3, 1])
            elif len(imgs[i].shape) == 4:  # Tensor, bcyx --> byxc
                imgs[i] = np.transpose(imgs[i], [0, 2, 3, 1])
            else:
                raise ValueError
        elif isinstance(imgs[i], Image.Image):
            imgs[i] = np.array(imgs[i]).astype(np.float32)
            imgs[i] = np.clip(imgs[i] / 255, 0, 1)
            if len(imgs[i].shape) == 2:    # PIL img, yx --> 1yx1
                imgs[i] = imgs[i][None, :, :, None]
            elif len(imgs[i].shape) == 3:  # PIL img, yxc --> 1yxc
                imgs[i] = imgs[i][None, :, :, :3]  # if it has alpha, it is trimmed.
            else:
                raise ValueError
        else:
            raise ValueError

    # all images must share the same byxc shape before concatenation
    img = np.concatenate(imgs)

    # check color
    b, y, x, c = img.shape
    if is_color and c != 3 and c != 4:
        raise ValueError

    # np, bcyx, float32
    img = np.transpose(img, [0, 3, 1, 2])

    # set nx automatically from the image count
    num_img = img.shape[0] if is_color else img.shape[0] * img.shape[1]
    if nx is None:
        if num_img < 10:
            nx = num_img
        else:
            nx = 10

    if not is_color:  # gray: tile every channel of every image onto the grid
        ny = int(np.ceil(np.float32(b * c) / nx))
        img = img.reshape(b * c, y, x)
        black = np.zeros([ny * nx - b * c, y, x], np.float32)
        img = np.concatenate([img, black])  # ny*nx,y,x
        img = img.reshape(ny, nx, y, x)
        img = img.transpose(0, 2, 1, 3)     # ny,y,nx,x
        img = img.reshape(ny * y, nx * x)   # ny*y,nx*x complete
    else:  # color: tile the b color images onto the grid
        ny = int(np.ceil(np.float32(b) / nx))
        black = np.zeros([ny * nx - b, c, y, x], np.float32)
        img = np.concatenate([img, black])    # ny*nx,c,y,x
        img = img.reshape(ny, nx, c, y, x)
        img = img.transpose(0, 3, 1, 4, 2)    # ny,y,nx,x,c
        img = img.reshape(ny * y, nx * x, c)  # ny*y,nx*x,c complete

    # rescale
    img = normalize_img(img, vmin, vmax)

    # plot or save:
    if path is None:
        plt.figure(num=title)
        if is_color:
            fig = plt.imshow(img, interpolation="nearest", vmin=0, vmax=1)
        else:
            fig = plt.imshow(img, cmap="gray", interpolation="nearest", vmin=0, vmax=1)
        fig.axes.get_xaxis().set_visible(False)
        fig.axes.get_yaxis().set_visible(False)
        plt.pause(0.001)
        plt.show(block=False)
        plt.tight_layout()
    else:
        if not exist(parent(path)):
            mkdir(parent(path))
        img = np.clip(img * 255, 0, 255).astype(np.uint8)
        img = Image.fromarray(img)
        img.save(path)


def imshowc(*args, nx=None, vmin=0, vmax=1, path=None, normalize_uint8=True, title=None):
    """
    Draw a color image.

    :param args: accepts the following formats
        Tensor:
            float32, bcyx | cyx
        np:
            uint8, byxc | yxc
            float32, byxc | yxc
        PIL image
    :param nx: number of image columns.
        (default) if the image count is greater than ten, nx is fixed at ten.
    :param vmin: (default) minimum value of the image.
    :param vmax: (default) maximum value of the image.
    :param path: save path. If None, the image is not saved.
    :param normalize_uint8: if the input is [np.uint8], decides whether to divide by 255.
    :param title: figure name
    """
    imshow(*args, nx=nx, vmin=vmin, vmax=vmax, path=path, is_color=True, normalize_uint8=normalize_uint8, title=title)
--------------------------------------------------------------------------------
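Typical calls, as a sketch (the tensors are illustrative placeholders):

imshow(torch.rand(8, 1, 64, 64), nx=4)                     # 8 gray tiles in two rows of four
imshowc(torch.rand(2, 3, 64, 64), path='out/preview.png')  # save a color grid to disk

Passing a path writes the tiled grid to disk; otherwise the grid is drawn with matplotlib.
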
/util/logger.py:
--------------------------------------------------------------------------------
"""
The logger class. This class mostly writes logs to TensorBoard.
It can also save images, which are stored under the './runs' path.

Library: Tensorflow 2.2.0, PyTorch 1.5.1
Author: Ian Yoo
Email: thyoostar@gmail.com
"""
from __future__ import absolute_import, division, print_function

import os
import numpy as np
import errno
import torchvision.utils as vutils
from tensorboardX import SummaryWriter
from matplotlib import pyplot as plt
import pathlib
import datetime
from util.checkpoint import *
import copy

class Logger:

    def __init__(self, model_name, data_name):
        self.model_name = model_name
        self.data_name = data_name

        self.comment = '{}_{}'.format(model_name, data_name)
        self.data_subdir = '{}/{}'.format(model_name, data_name)

        current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_")

        train_log_dir = r'runs/' + current_time + self.comment + r'/train'
        test_log_dir = r'runs/' + current_time + self.comment + r'/test'

        self.hdl_chkpoint = CheckpointHandler()

        # TensorBoard
        self.writer_train = SummaryWriter(train_log_dir, comment=self.comment)
        self.writer_test = SummaryWriter(test_log_dir, comment=self.comment)

    def log_train(self, scalar, title, epoch, n_batch, num_batches):
        step = Logger._step(epoch, n_batch, num_batches)
        self.writer_train.add_scalar(
            '{}/{}'.format(self.comment, title), scalar, step)

    def log_test(self, scalar, title, epoch, n_batch, num_batches):
        step = Logger._step(epoch, n_batch, num_batches)
        self.writer_test.add_scalar(
            '{}/{}'.format(self.comment, title), scalar, step)

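    # The global TensorBoard step is computed as epoch * num_batches + n_batch
    # (see Logger._step below). For example, with num_batches=100:
    #   epoch=0, n_batch=30 -> step 30
    #   epoch=2, n_batch=30 -> step 230
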
    def concatenate_images(self, *args, input_axis='bcyx', normalize_uint8=False):
        """
        Concatenate images along the batch axis and return the result.

        :param args: accepts the following formats
            Tensor: float32
            np: uint8, int64, float32
        :param input_axis: axis order of the input; it is transposed to 'bcyx'.
            available: (bcyx | byxc | cyx | yxc | byx)
        :param normalize_uint8: if dtype is [np.uint8], values are divided by 255.
        :return: a float32 Tensor in 'bcyx' order
        """
        imgs = []
        for arg in args:
            if isinstance(arg, list):
                imgs += arg
            else:
                imgs.append(arg)

        assert len(imgs) >= 1

        for i in range(len(imgs)):
            if isinstance(imgs[i], np.ndarray):
                if imgs[i].dtype == np.uint8 or imgs[i].dtype == np.int64:
                    if normalize_uint8:
                        imgs[i] = np.clip(imgs[i].astype(np.float32) / 255, 0.0, 1.0)
                    else:
                        imgs[i] = imgs[i].astype(np.float32)
                imgs[i] = torch.from_numpy(imgs[i])

            if imgs[i].dtype != torch.float32:
                imgs[i] = imgs[i].float()

            if len(imgs[i].shape) == 2:    # Tensor, yx --> 11yx
                imgs[i] = imgs[i][None, None, :, :]
            elif len(imgs[i].shape) == 3:  # Tensor, cyx --> 1cyx
                imgs[i] = imgs[i][None, :, :, :]

            # swap axes to 'bcyx'
            if input_axis == 'byxc' or input_axis == 'yxc':
                imgs[i] = imgs[i].transpose(1, 3)
                imgs[i] = imgs[i].transpose(1, 2)
            elif input_axis == 'byx':
                imgs[i] = imgs[i].transpose(0, 1)

            if imgs[i].shape[1] == 1:
                # repeat a single channel three times, scaling values down by 3.
                imgs[i] = imgs[i].repeat((1, 3, 1, 1)) / 3.0

        return torch.cat(imgs)

    def log_images_train(self, images, epoch, n_batch, num_batches, input_axis='bcyx',
                         nrows=8, padding=2, pad_value=1, normalize=True, normalize_uint8=False):
        """
        Write images to TensorBoard and save the file under [./runs/images/].

        :param images: accepts the following formats
            Tensor: float32
            np: uint8, int64, float32
        :param epoch: epoch.
        :param n_batch: batch index.
        :param num_batches: batch count.
        :param input_axis: axis order of the input; it is transposed to 'bcyx'.
            available: (bcyx | byxc | cyx | yxc | byx)
        :param nrows: images per grid row.
        :param padding: amount of padding.
        :param pad_value: padding scalar value, in the range [0, 1].
        :param normalize: normalize the image to the range [0, 1].
        :param normalize_uint8: if dtype is [np.uint8], values are divided by 255.
        """

        img_name, grid, step = self._log_images(images, epoch, n_batch, num_batches,
                                                input_axis, nrows, padding, pad_value, normalize,
                                                comment='train')

        # Add images to TensorBoard
        self.writer_train.add_image(img_name, grid, step)

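    # How the trainer builds one grid from predictions, labels and inputs
    # (shapes are illustrative):
    #   lbl_pred, lbl_true: (b, y, x) int arrays; data: (b, c, y, x) float tensor
    #   log_img = logger.concatenate_images([lbl_pred, lbl_true], input_axis='byx')
    #   log_img = logger.concatenate_images([log_img, data.cpu().numpy()])
    #   logger.log_images_train(log_img, epoch, n_batch, num_batches, nrows=b)
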
    def log_images_test(self, images, epoch, n_batch, num_batches, input_axis='bcyx',
                        nrows=8, padding=2, pad_value=1, normalize=True, normalize_uint8=False):
        """
        Write images to TensorBoard and save the file under [./runs/images/].

        :param images: accepts the following formats
            Tensor: float32
            np: uint8, int64, float32
        :param epoch: epoch.
        :param n_batch: batch index.
        :param num_batches: batch count.
        :param input_axis: axis order of the input; it is transposed to 'bcyx'.
            available: (bcyx | byxc | cyx | yxc | byx)
        :param nrows: images per grid row.
        :param padding: amount of padding.
        :param pad_value: padding scalar value, in the range [0, 1].
        :param normalize: normalize the image to the range [0, 1].
        :param normalize_uint8: if dtype is [np.uint8], values are divided by 255.
        """

        img_name, grid, step = self._log_images(images, epoch, n_batch, num_batches,
                                                input_axis, nrows, padding, pad_value, normalize,
                                                comment='test')

        # Add images to TensorBoard
        self.writer_test.add_image(img_name, grid, step)

    def _log_images(self, images, epoch, n_batch, num_batches, input_axis='bcyx',
                    nrows=8, padding=2, pad_value=1, normalize=True, normalize_uint8=False, comment=''):

        if isinstance(images, np.ndarray):
            if images.dtype == np.uint8 or images.dtype == np.int64:
                if normalize_uint8:
                    images = np.clip(images.astype(np.float32) / 255, 0.0, 1.0)
                else:
                    images = images.astype(np.float32)
            images = torch.from_numpy(images)

        if len(images.shape) == 2:    # Tensor, yx --> 11yx
            images = images[None, None, :, :]
        elif len(images.shape) == 3:  # Tensor, cyx --> 1cyx
            images = images[None, :, :, :]

        # swap axes to 'bcyx'
        if input_axis == 'byxc' or input_axis == 'yxc':
            images = images.transpose(1, 3)
            images = images.transpose(1, 2)
        elif input_axis == 'byx':
            images = images.transpose(0, 1)

        step = Logger._step(epoch, n_batch, num_batches)
        img_name = '{}/images'.format(self.comment)

        # Make grid from image tensor
        if images.shape[0] < nrows:
            nrows = images.shape[0]

        grid = vutils.make_grid(images, nrow=nrows, normalize=normalize,
                                scale_each=True, pad_value=pad_value, padding=padding)

        # Save plots
        self._save_torch_images(grid, epoch, n_batch, comment)

        return img_name, grid, step

    def _save_torch_images(self, grid, epoch, n_batch, comment=''):
        out_dir = './runs/images/{}'.format(self.data_subdir)
        Logger._make_dir(out_dir)

        # Save squared
        fig = plt.figure()
        plt.imshow(np.moveaxis(grid.numpy(), 0, -1))
        plt.axis('off')
        if comment:
            fig.savefig('{}/{}_epoch_{}_batch_{}.png'.format(out_dir, comment, epoch, n_batch))
        else:
            fig.savefig('{}/epoch_{}_batch_{}.png'.format(out_dir, epoch, n_batch))
        plt.close()

    def store_checkpoint_var(self, key, value):
        self.hdl_chkpoint.store_var(key, value)

    def save_model(self, model, file_name):
        out_dir = './runs/models/{}'.format(self.data_subdir)
        if not Logger._exist(out_dir):
            Logger._make_dir(out_dir)

        self.hdl_chkpoint.save_checkpoint('{}/{}'.format(out_dir, file_name), model)

    def save_model_and_optimizer(self, model, optim, file_name):
        out_dir = './runs/models/{}'.format(self.data_subdir)
        if not Logger._exist(out_dir):
            Logger._make_dir(out_dir)

        self.hdl_chkpoint.save_checkpoint('{}/{}'.format(out_dir, file_name), model, optim)

    def load_model(self, model, file_name):
        out_dir = './runs/models/{}'.format(self.data_subdir)
        assert Logger._exist(out_dir)

        self.hdl_chkpoint = self.hdl_chkpoint.load_checkpoint('{}/{}'.format(out_dir, file_name))

        model.load_state_dict(self.hdl_chkpoint.model_state_dict)
        if hasattr(self.hdl_chkpoint, '__dict__'):
            for k in self.hdl_chkpoint.__dict__:
                if k == 'model_state_dict' or k == 'optimizer_state_dict':
                    continue
                attr_copy = copy.deepcopy(getattr(self.hdl_chkpoint, k))
                setattr(model, k, attr_copy)

    # def load_model_and_optimizer(self, model, optim, file_name):
    #     out_dir = './runs/models/{}'.format(self.data_subdir)
    #     assert Logger._exist(out_dir)
    #
    #     self.hdl_chkpoint = self.hdl_chkpoint.load_checkpoint('{}/{}'.format(out_dir, file_name))
    #
    #     model.load_state_dict(self.hdl_chkpoint.model_state_dict)
    #     optim.load_state_dict(self.hdl_chkpoint.optimizer_state_dict)
    #     if hasattr(self.hdl_chkpoint, '__dict__'):
    #         for k in self.hdl_chkpoint.__dict__:
    #             if k == 'model_state_dict' or k == 'optimizer_state_dict':
    #                 continue
    #             attr_copy = copy.deepcopy(getattr(self.hdl_chkpoint, k))
    #             setattr(model, k, attr_copy)

    def close(self):
        self.writer_train.close()
        self.writer_test.close()

    @staticmethod
    def _step(epoch, n_batch, num_batches):
        return epoch * num_batches + n_batch

    @staticmethod
    def _make_dir(directory):
        try:
            os.makedirs(directory)
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise

    @staticmethod
    def _parent(path):
        path = pathlib.Path(path)
        return str(path.parent)

    @staticmethod
    def _exist(path):
        return os.path.exists(str(path))
--------------------------------------------------------------------------------
/util/validation.py:
--------------------------------------------------------------------------------
"""
Validation functions. The following functions compute scalar metrics for validation.

Library: Tensorflow 2.2.0, PyTorch 1.5.1
Author: Ian Yoo
Email: thyoostar@gmail.com
"""
from __future__ import absolute_import, division, print_function

import numpy as np

def label_accuracy_score(label_trues, label_preds, n_class):
    """
    :param label_trues: iterable of ground-truth label maps.
    :param label_preds: iterable of predicted label maps.
    :param n_class: number of classes.
    :return: accuracy scores and evaluation results
        (overall accuracy, mean class accuracy, mean IoU, fwavacc)
    """

    hist = np.zeros((n_class, n_class))
    for lt, lp in zip(label_trues, label_preds):
        hist += _fast_hist(lt.flatten(), lp.flatten(), n_class)
    acc = np.diag(hist).sum() / hist.sum()
    with np.errstate(divide='ignore', invalid='ignore'):
        acc_cls = np.diag(hist) / hist.sum(axis=1)
    acc_cls = np.nanmean(acc_cls)
    with np.errstate(divide='ignore', invalid='ignore'):
        iu = np.diag(hist) / (
            (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)).astype(np.float32)
        )

    mean_iou = np.nanmean(iu)
    freq = hist.sum(axis=1) / hist.sum()
    fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
    return acc, acc_cls, mean_iou, fwavacc

def _fast_hist(label_true, label_pred, n_class):
    # confusion matrix via bincount over n_class * true + pred
    mask = (label_true >= 0) & (label_true < n_class)
    hist = np.bincount(
        n_class * label_true[mask].astype(int) +
        label_pred[mask], minlength=n_class ** 2).reshape(n_class, n_class)
    return hist
--------------------------------------------------------------------------------
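A toy check of label_accuracy_score, with the expected values worked out by hand for a 2-class, 2x2 case:

import numpy as np
from util.validation import label_accuracy_score

lt = [np.array([[0, 1], [1, 1]])]  # ground truth
lp = [np.array([[0, 1], [0, 1]])]  # prediction with one wrong pixel
acc, acc_cls, mean_iou, fwavacc = label_accuracy_score(lt, lp, n_class=2)
# acc = 0.75, acc_cls = (1.0 + 2/3) / 2 ~= 0.833,
# mean_iou = (0.5 + 2/3) / 2 ~= 0.583, fwavacc = 0.625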