├── .gitignore ├── LICENSE ├── README.md ├── configs ├── fcn8s_pascal.yml └── frrnB_cityscapes.yml ├── ptsemseg ├── __init__.py ├── augmentations │ ├── __init__.py │ └── augmentations.py ├── caffe_pb2.py ├── loader │ ├── __init__.py │ ├── ade20k_loader.py │ ├── camvid_loader.py │ ├── cityscapes_loader.py │ ├── mapillary_vistas_loader.py │ ├── mit_sceneparsing_benchmark_loader.py │ ├── nyuv2_loader.py │ ├── pascal_voc_loader.py │ └── sunrgbd_loader.py ├── loss │ ├── __init__.py │ └── loss.py ├── metrics.py ├── models │ ├── __init__.py │ ├── fcn.py │ ├── frrn.py │ ├── icnet.py │ ├── linknet.py │ ├── pspnet.py │ ├── refinenet.py │ ├── segnet.py │ ├── unet.py │ └── utils.py ├── optimizers │ └── __init__.py ├── schedulers │ ├── __init__.py │ └── schedulers.py └── utils.py ├── requirements.txt ├── test.py ├── train.py └── validate.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # Torch Models 7 | *.pkl 8 | *.pth 9 | current_train.py 10 | video_test*.py 11 | *.swp 12 | data 13 | ckpt 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | local_test.py 21 | .DS_STORE 22 | .idea/ 23 | .vscode/ 24 | env/ 25 | build/ 26 | develop-eggs/ 27 | dist/ 28 | downloads/ 29 | eggs/ 30 | .eggs/ 31 | lib/ 32 | lib64/ 33 | parts/ 34 | sdist/ 35 | var/ 36 | *.egg-info/ 37 | .installed.cfg 38 | *.egg 39 | 40 | # PyInstaller 41 | # Usually these files are written by a python script from a template 42 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 43 | *.manifest 44 | *.spec 45 | 46 | # Installer logs 47 | pip-log.txt 48 | pip-delete-this-directory.txt 49 | 50 | # Unit test / coverage reports 51 | htmlcov/ 52 | .tox/ 53 | .coverage 54 | .coverage.* 55 | .cache 56 | nosetests.xml 57 | coverage.xml 58 | *,cover 59 | .hypothesis/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | target/ 81 | 82 | # IPython Notebook 83 | .ipynb_checkpoints 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # celery beat schedule file 89 | celerybeat-schedule 90 | 91 | # dotenv 92 | .env 93 | 94 | # virtualenv 95 | venv/ 96 | ENV/ 97 | 98 | # Spyder project settings 99 | .spyderproject 100 | 101 | # Rope project settings 102 | .ropeproject 103 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Meet Pragnesh Shah 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytorch-semseg 2 | 3 | [![license](https://img.shields.io/github/license/mashape/apistatus.svg)](https://github.com/meetshah1995/pytorch-semseg/blob/master/LICENSE) 4 | [![pypi](https://img.shields.io/pypi/v/pytorch_semseg.svg)](https://pypi.python.org/pypi/pytorch-semseg/0.1.2) 5 | [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.1185075.svg)](https://doi.org/10.5281/zenodo.1185075) 6 | 7 | 8 | 9 | ## Semantic Segmentation Algorithms Implemented in PyTorch 10 | 11 | This repository aims at mirroring popular semantic segmentation architectures in PyTorch. 12 | 13 | 14 |


18 | 19 | 20 | ### Networks implemented 21 | 22 | * [PSPNet](https://arxiv.org/abs/1612.01105) - With support for loading pretrained models w/o caffe dependency 23 | * [ICNet](https://arxiv.org/pdf/1704.08545.pdf) - With optional batchnorm and pretrained models 24 | * [FRRN](https://arxiv.org/abs/1611.08323) - Model A and B 25 | * [FCN](https://arxiv.org/abs/1411.4038) - All 1 (FCN32s), 2 (FCN16s) and 3 (FCN8s) stream variants 26 | * [U-Net](https://arxiv.org/abs/1505.04597) - With optional deconvolution and batchnorm 27 | * [Link-Net](https://codeac29.github.io/projects/linknet/) - With multiple resnet backends 28 | * [Segnet](https://arxiv.org/abs/1511.00561) - With Unpooling using Maxpool indices 29 | 30 | 31 | #### Upcoming 32 | 33 | * [E-Net](https://arxiv.org/abs/1606.02147) 34 | * [RefineNet](https://arxiv.org/abs/1611.06612) 35 | 36 | ### DataLoaders implemented 37 | 38 | * [CamVid](http://mi.eng.cam.ac.uk/research/projects/VideoRec/CamVid/) 39 | * [Pascal VOC](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/segexamples/index.html) 40 | * [ADE20K](http://groups.csail.mit.edu/vision/datasets/ADE20K/) 41 | * [MIT Scene Parsing Benchmark](http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip) 42 | * [Cityscapes](https://www.cityscapes-dataset.com/) 43 | * [NYUDv2](http://cs.nyu.edu/~silberman/datasets/nyu_depth_v2.html) 44 | * [Sun-RGBD](http://rgbd.cs.princeton.edu/) 45 | 46 | 47 | ### Requirements 48 | 49 | * pytorch >=0.4.0 50 | * torchvision ==0.2.0 51 | * scipy 52 | * tqdm 53 | * tensorboardX 54 | 55 | #### One-line installation 56 | 57 | `pip install -r requirements.txt` 58 | 59 | ### Data 60 | 61 | * Download data for desired dataset(s) from list of URLs [here](https://meetshah1995.github.io/semantic-segmentation/deep-learning/pytorch/visdom/2017/06/01/semantic-segmentation-over-the-years.html#sec_datasets). 
62 | * Extract the zip / tar and modify the path appropriately in your `config.yaml` 63 | 64 | 65 | ### Usage 66 | 67 | **Setup config file** 68 | 69 | ```yaml 70 | # Model Configuration 71 | model: 72 | arch: [options: 'fcn[8,16,32]s, unet, segnet, pspnet, icnet, icnetBN, linknet, frrn[A,B]' 73 | : 74 | 75 | # Data Configuration 76 | data: 77 | dataset: [options: 'pascal, camvid, ade20k, mit_sceneparsing_benchmark, cityscapes, nyuv2, sunrgbd, vistas'] 78 | train_split: 79 | val_split: 80 | img_rows: 512 81 | img_cols: 1024 82 | path: 83 | : 84 | 85 | # Training Configuration 86 | training: 87 | n_workers: 64 88 | train_iters: 35000 89 | batch_size: 16 90 | val_interval: 500 91 | print_interval: 25 92 | loss: 93 | name: [options: 'cross_entropy, bootstrapped_cross_entropy, multi_scale_crossentropy'] 94 | : 95 | 96 | # Optmizer Configuration 97 | optimizer: 98 | name: [options: 'sgd, adam, adamax, asgd, adadelta, adagrad, rmsprop'] 99 | lr: 1.0e-3 100 | : 101 | 102 | # Warmup LR Configuration 103 | warmup_iters: 104 | mode: <'constant' or 'linear' for warmup'> 105 | gamma: 106 | 107 | # Augmentations Configuration 108 | augmentations: 109 | gamma: x #[gamma varied in 1 to 1+x] 110 | hue: x #[hue varied in -x to x] 111 | brightness: x #[brightness varied in 1-x to 1+x] 112 | saturation: x #[saturation varied in 1-x to 1+x] 113 | contrast: x #[contrast varied in 1-x to 1+x] 114 | rcrop: [h, w] #[crop of size (h,w)] 115 | translate: [dh, dw] #[reflective translation by (dh, dw)] 116 | rotate: d #[rotate -d to d degrees] 117 | scale: [h,w] #[scale to size (h,w)] 118 | ccrop: [h,w] #[center crop of (h,w)] 119 | hflip: p #[flip horizontally with chance p] 120 | vflip: p #[flip vertically with chance p] 121 | 122 | # LR Schedule Configuration 123 | lr_schedule: 124 | name: [options: 'constant_lr, poly_lr, multi_step, cosine_annealing, exp_lr'] 125 | : 126 | 127 | # Resume from checkpoint 128 | resume: 129 | ``` 130 | 131 | **To train the model :** 132 | 133 | ``` 134 | python train.py [-h] [--config [CONFIG]] 135 | 136 | --config Configuration file to use 137 | ``` 138 | 139 | **To validate the model :** 140 | 141 | ``` 142 | usage: validate.py [-h] [--config [CONFIG]] [--model_path [MODEL_PATH]] 143 | [--eval_flip] [--measure_time] 144 | 145 | --config Config file to be used 146 | --model_path Path to the saved model 147 | --eval_flip Enable evaluation with flipped image | True by default 148 | --measure_time Enable evaluation with time (fps) measurement | True 149 | by default 150 | ``` 151 | 152 | **To test the model w.r.t. 
a dataset on custom images(s):** 153 | 154 | ``` 155 | python test.py [-h] [--model_path [MODEL_PATH]] [--dataset [DATASET]] 156 | [--dcrf [DCRF]] [--img_path [IMG_PATH]] [--out_path [OUT_PATH]] 157 | 158 | --model_path Path to the saved model 159 | --dataset Dataset to use ['pascal, camvid, ade20k etc'] 160 | --dcrf Enable DenseCRF based post-processing 161 | --img_path Path of the input image 162 | --out_path Path of the output segmap 163 | ``` 164 | 165 | 166 | **If you find this code useful in your research, please consider citing:** 167 | 168 | ``` 169 | @article{mshahsemseg, 170 | Author = {Meet P Shah}, 171 | Title = {Semantic Segmentation Architectures Implemented in PyTorch.}, 172 | Journal = {https://github.com/meetshah1995/pytorch-semseg}, 173 | Year = {2017} 174 | } 175 | ``` 176 | 177 | -------------------------------------------------------------------------------- /configs/fcn8s_pascal.yml: -------------------------------------------------------------------------------- 1 | model: 2 | arch: fcn8s 3 | data: 4 | dataset: pascal 5 | train_split: train_aug 6 | val_split: val 7 | img_rows: 'same' 8 | img_cols: 'same' 9 | path: /private/home/meetshah/datasets/VOC/060817/VOCdevkit/VOC2012/ 10 | sbd_path: /private/home/meetshah/datasets/VOC/benchmark_RELEASE/ 11 | training: 12 | train_iters: 300000 13 | batch_size: 1 14 | val_interval: 1000 15 | n_workers: 16 16 | print_interval: 50 17 | optimizer: 18 | name: 'sgd' 19 | lr: 1.0e-10 20 | weight_decay: 0.0005 21 | momentum: 0.99 22 | loss: 23 | name: 'cross_entropy' 24 | size_average: False 25 | lr_schedule: 26 | resume: fcn8s_pascal_best_model.pkl 27 | -------------------------------------------------------------------------------- /configs/frrnB_cityscapes.yml: -------------------------------------------------------------------------------- 1 | model: 2 | arch: frrnB 3 | data: 4 | dataset: cityscapes 5 | train_split: train 6 | val_split: val 7 | img_rows: 512 8 | img_cols: 1024 9 | path: /private/home/meetshah/misc_code/ps/data/VOCdevkit/VOC2012/ 10 | training: 11 | train_iters: 85000 12 | batch_size: 2 13 | val_interval: 500 14 | print_interval: 25 15 | optimizer: 16 | lr: 1.0e-4 17 | l_rate: 1.0e-4 18 | l_schedule: 19 | momentum: 0.99 20 | weight_decay: 0.0005 21 | resume: frrnB_cityscapes_best_model.pkl 22 | visdom: False 23 | -------------------------------------------------------------------------------- /ptsemseg/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/meetps/pytorch-semseg/801fb200547caa5b0d91b8dde56b837da029f746/ptsemseg/__init__.py -------------------------------------------------------------------------------- /ptsemseg/augmentations/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from ptsemseg.augmentations.augmentations import ( 3 | AdjustContrast, 4 | AdjustGamma, 5 | AdjustBrightness, 6 | AdjustSaturation, 7 | AdjustHue, 8 | RandomCrop, 9 | RandomHorizontallyFlip, 10 | RandomVerticallyFlip, 11 | Scale, 12 | RandomSized, 13 | RandomSizedCrop, 14 | RandomRotate, 15 | RandomTranslate, 16 | CenterCrop, 17 | Compose, 18 | ) 19 | 20 | logger = logging.getLogger("ptsemseg") 21 | 22 | key2aug = { 23 | "gamma": AdjustGamma, 24 | "hue": AdjustHue, 25 | "brightness": AdjustBrightness, 26 | "saturation": AdjustSaturation, 27 | "contrast": AdjustContrast, 28 | "rcrop": RandomCrop, 29 | "hflip": RandomHorizontallyFlip, 30 | "vflip": RandomVerticallyFlip, 31 | "scale": Scale, 32 
| "rsize": RandomSized, 33 | "rsizecrop": RandomSizedCrop, 34 | "rotate": RandomRotate, 35 | "translate": RandomTranslate, 36 | "ccrop": CenterCrop, 37 | } 38 | 39 | 40 | def get_composed_augmentations(aug_dict): 41 | if aug_dict is None: 42 | logger.info("Using No Augmentations") 43 | return None 44 | 45 | augmentations = [] 46 | for aug_key, aug_param in aug_dict.items(): 47 | augmentations.append(key2aug[aug_key](aug_param)) 48 | logger.info("Using {} aug with params {}".format(aug_key, aug_param)) 49 | return Compose(augmentations) 50 | -------------------------------------------------------------------------------- /ptsemseg/augmentations/augmentations.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numbers 3 | import random 4 | import numpy as np 5 | import torchvision.transforms.functional as tf 6 | 7 | from PIL import Image, ImageOps 8 | 9 | 10 | class Compose(object): 11 | def __init__(self, augmentations): 12 | self.augmentations = augmentations 13 | self.PIL2Numpy = False 14 | 15 | def __call__(self, img, mask): 16 | if isinstance(img, np.ndarray): 17 | img = Image.fromarray(img, mode="RGB") 18 | mask = Image.fromarray(mask, mode="L") 19 | self.PIL2Numpy = True 20 | 21 | assert img.size == mask.size 22 | for a in self.augmentations: 23 | img, mask = a(img, mask) 24 | 25 | if self.PIL2Numpy: 26 | img, mask = np.array(img), np.array(mask, dtype=np.uint8) 27 | 28 | return img, mask 29 | 30 | 31 | class RandomCrop(object): 32 | def __init__(self, size, padding=0): 33 | if isinstance(size, numbers.Number): 34 | self.size = (int(size), int(size)) 35 | else: 36 | self.size = size 37 | self.padding = padding 38 | 39 | def __call__(self, img, mask): 40 | if self.padding > 0: 41 | img = ImageOps.expand(img, border=self.padding, fill=0) 42 | mask = ImageOps.expand(mask, border=self.padding, fill=0) 43 | 44 | assert img.size == mask.size 45 | w, h = img.size 46 | th, tw = self.size 47 | if w == tw and h == th: 48 | return img, mask 49 | if w < tw or h < th: 50 | return (img.resize((tw, th), Image.BILINEAR), mask.resize((tw, th), Image.NEAREST)) 51 | 52 | x1 = random.randint(0, w - tw) 53 | y1 = random.randint(0, h - th) 54 | return (img.crop((x1, y1, x1 + tw, y1 + th)), mask.crop((x1, y1, x1 + tw, y1 + th))) 55 | 56 | 57 | class AdjustGamma(object): 58 | def __init__(self, gamma): 59 | self.gamma = gamma 60 | 61 | def __call__(self, img, mask): 62 | assert img.size == mask.size 63 | return tf.adjust_gamma(img, random.uniform(1, 1 + self.gamma)), mask 64 | 65 | 66 | class AdjustSaturation(object): 67 | def __init__(self, saturation): 68 | self.saturation = saturation 69 | 70 | def __call__(self, img, mask): 71 | assert img.size == mask.size 72 | return ( 73 | tf.adjust_saturation(img, random.uniform(1 - self.saturation, 1 + self.saturation)), 74 | mask, 75 | ) 76 | 77 | 78 | class AdjustHue(object): 79 | def __init__(self, hue): 80 | self.hue = hue 81 | 82 | def __call__(self, img, mask): 83 | assert img.size == mask.size 84 | return tf.adjust_hue(img, random.uniform(-self.hue, self.hue)), mask 85 | 86 | 87 | class AdjustBrightness(object): 88 | def __init__(self, bf): 89 | self.bf = bf 90 | 91 | def __call__(self, img, mask): 92 | assert img.size == mask.size 93 | return tf.adjust_brightness(img, random.uniform(1 - self.bf, 1 + self.bf)), mask 94 | 95 | 96 | class AdjustContrast(object): 97 | def __init__(self, cf): 98 | self.cf = cf 99 | 100 | def __call__(self, img, mask): 101 | assert img.size == mask.size 102 | return 
tf.adjust_contrast(img, random.uniform(1 - self.cf, 1 + self.cf)), mask 103 | 104 | 105 | class CenterCrop(object): 106 | def __init__(self, size): 107 | if isinstance(size, numbers.Number): 108 | self.size = (int(size), int(size)) 109 | else: 110 | self.size = size 111 | 112 | def __call__(self, img, mask): 113 | assert img.size == mask.size 114 | w, h = img.size 115 | th, tw = self.size 116 | x1 = int(round((w - tw) / 2.0)) 117 | y1 = int(round((h - th) / 2.0)) 118 | return (img.crop((x1, y1, x1 + tw, y1 + th)), mask.crop((x1, y1, x1 + tw, y1 + th))) 119 | 120 | 121 | class RandomHorizontallyFlip(object): 122 | def __init__(self, p): 123 | self.p = p 124 | 125 | def __call__(self, img, mask): 126 | if random.random() < self.p: 127 | return (img.transpose(Image.FLIP_LEFT_RIGHT), mask.transpose(Image.FLIP_LEFT_RIGHT)) 128 | return img, mask 129 | 130 | 131 | class RandomVerticallyFlip(object): 132 | def __init__(self, p): 133 | self.p = p 134 | 135 | def __call__(self, img, mask): 136 | if random.random() < self.p: 137 | return (img.transpose(Image.FLIP_TOP_BOTTOM), mask.transpose(Image.FLIP_TOP_BOTTOM)) 138 | return img, mask 139 | 140 | 141 | class FreeScale(object): 142 | def __init__(self, size): 143 | self.size = tuple(reversed(size)) # size: (h, w) 144 | 145 | def __call__(self, img, mask): 146 | assert img.size == mask.size 147 | return (img.resize(self.size, Image.BILINEAR), mask.resize(self.size, Image.NEAREST)) 148 | 149 | 150 | class RandomTranslate(object): 151 | def __init__(self, offset): 152 | # tuple (delta_x, delta_y) 153 | self.offset = offset 154 | 155 | def __call__(self, img, mask): 156 | assert img.size == mask.size 157 | x_offset = int(2 * (random.random() - 0.5) * self.offset[0]) 158 | y_offset = int(2 * (random.random() - 0.5) * self.offset[1]) 159 | 160 | x_crop_offset = x_offset 161 | y_crop_offset = y_offset 162 | if x_offset < 0: 163 | x_crop_offset = 0 164 | if y_offset < 0: 165 | y_crop_offset = 0 166 | 167 | cropped_img = tf.crop( 168 | img, 169 | y_crop_offset, 170 | x_crop_offset, 171 | img.size[1] - abs(y_offset), 172 | img.size[0] - abs(x_offset), 173 | ) 174 | 175 | if x_offset >= 0 and y_offset >= 0: 176 | padding_tuple = (0, 0, x_offset, y_offset) 177 | 178 | elif x_offset >= 0 and y_offset < 0: 179 | padding_tuple = (0, abs(y_offset), x_offset, 0) 180 | 181 | elif x_offset < 0 and y_offset >= 0: 182 | padding_tuple = (abs(x_offset), 0, 0, y_offset) 183 | 184 | elif x_offset < 0 and y_offset < 0: 185 | padding_tuple = (abs(x_offset), abs(y_offset), 0, 0) 186 | 187 | return ( 188 | tf.pad(cropped_img, padding_tuple, padding_mode="reflect"), 189 | tf.affine( 190 | mask, 191 | translate=(-x_offset, -y_offset), 192 | scale=1.0, 193 | angle=0.0, 194 | shear=0.0, 195 | fillcolor=250, 196 | ), 197 | ) 198 | 199 | 200 | class RandomRotate(object): 201 | def __init__(self, degree): 202 | self.degree = degree 203 | 204 | def __call__(self, img, mask): 205 | rotate_degree = random.random() * 2 * self.degree - self.degree 206 | return ( 207 | tf.affine( 208 | img, 209 | translate=(0, 0), 210 | scale=1.0, 211 | angle=rotate_degree, 212 | resample=Image.BILINEAR, 213 | fillcolor=(0, 0, 0), 214 | shear=0.0, 215 | ), 216 | tf.affine( 217 | mask, 218 | translate=(0, 0), 219 | scale=1.0, 220 | angle=rotate_degree, 221 | resample=Image.NEAREST, 222 | fillcolor=250, 223 | shear=0.0, 224 | ), 225 | ) 226 | 227 | 228 | class Scale(object): 229 | def __init__(self, size): 230 | self.size = size 231 | 232 | def __call__(self, img, mask): 233 | assert img.size == mask.size 
234 | w, h = img.size 235 | if (w >= h and w == self.size) or (h >= w and h == self.size): 236 | return img, mask 237 | if w > h: 238 | ow = self.size 239 | oh = int(self.size * h / w) 240 | return (img.resize((ow, oh), Image.BILINEAR), mask.resize((ow, oh), Image.NEAREST)) 241 | else: 242 | oh = self.size 243 | ow = int(self.size * w / h) 244 | return (img.resize((ow, oh), Image.BILINEAR), mask.resize((ow, oh), Image.NEAREST)) 245 | 246 | 247 | class RandomSizedCrop(object): 248 | def __init__(self, size): 249 | self.size = size 250 | 251 | def __call__(self, img, mask): 252 | assert img.size == mask.size 253 | for attempt in range(10): 254 | area = img.size[0] * img.size[1] 255 | target_area = random.uniform(0.45, 1.0) * area 256 | aspect_ratio = random.uniform(0.5, 2) 257 | 258 | w = int(round(math.sqrt(target_area * aspect_ratio))) 259 | h = int(round(math.sqrt(target_area / aspect_ratio))) 260 | 261 | if random.random() < 0.5: 262 | w, h = h, w 263 | 264 | if w <= img.size[0] and h <= img.size[1]: 265 | x1 = random.randint(0, img.size[0] - w) 266 | y1 = random.randint(0, img.size[1] - h) 267 | 268 | img = img.crop((x1, y1, x1 + w, y1 + h)) 269 | mask = mask.crop((x1, y1, x1 + w, y1 + h)) 270 | assert img.size == (w, h) 271 | 272 | return ( 273 | img.resize((self.size, self.size), Image.BILINEAR), 274 | mask.resize((self.size, self.size), Image.NEAREST), 275 | ) 276 | 277 | # Fallback 278 | scale = Scale(self.size) 279 | crop = CenterCrop(self.size) 280 | return crop(*scale(img, mask)) 281 | 282 | 283 | class RandomSized(object): 284 | def __init__(self, size): 285 | self.size = size 286 | self.scale = Scale(self.size) 287 | self.crop = RandomCrop(self.size) 288 | 289 | def __call__(self, img, mask): 290 | assert img.size == mask.size 291 | 292 | w = int(random.uniform(0.5, 2) * img.size[0]) 293 | h = int(random.uniform(0.5, 2) * img.size[1]) 294 | 295 | img, mask = (img.resize((w, h), Image.BILINEAR), mask.resize((w, h), Image.NEAREST)) 296 | 297 | return self.crop(*self.scale(img, mask)) 298 | -------------------------------------------------------------------------------- /ptsemseg/loader/__init__.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from ptsemseg.loader.pascal_voc_loader import pascalVOCLoader 4 | from ptsemseg.loader.camvid_loader import camvidLoader 5 | from ptsemseg.loader.ade20k_loader import ADE20KLoader 6 | from ptsemseg.loader.mit_sceneparsing_benchmark_loader import MITSceneParsingBenchmarkLoader 7 | from ptsemseg.loader.cityscapes_loader import cityscapesLoader 8 | from ptsemseg.loader.nyuv2_loader import NYUv2Loader 9 | from ptsemseg.loader.sunrgbd_loader import SUNRGBDLoader 10 | from ptsemseg.loader.mapillary_vistas_loader import mapillaryVistasLoader 11 | 12 | 13 | def get_loader(name): 14 | """get_loader 15 | 16 | :param name: 17 | """ 18 | return { 19 | "pascal": pascalVOCLoader, 20 | "camvid": camvidLoader, 21 | "ade20k": ADE20KLoader, 22 | "mit_sceneparsing_benchmark": MITSceneParsingBenchmarkLoader, 23 | "cityscapes": cityscapesLoader, 24 | "nyuv2": NYUv2Loader, 25 | "sunrgbd": SUNRGBDLoader, 26 | "vistas": mapillaryVistasLoader, 27 | }[name] 28 | -------------------------------------------------------------------------------- /ptsemseg/loader/ade20k_loader.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import torch 3 | import torchvision 4 | import numpy as np 5 | import scipy.misc as m 6 | import matplotlib.pyplot as plt 7 | 8 
| from torch.utils import data 9 | 10 | from ptsemseg.utils import recursive_glob 11 | 12 | 13 | class ADE20KLoader(data.Dataset): 14 | def __init__( 15 | self, 16 | root, 17 | split="training", 18 | is_transform=False, 19 | img_size=512, 20 | augmentations=None, 21 | img_norm=True, 22 | test_mode=False, 23 | ): 24 | self.root = root 25 | self.split = split 26 | self.is_transform = is_transform 27 | self.augmentations = augmentations 28 | self.img_norm = img_norm 29 | self.test_mode = test_mode 30 | self.n_classes = 150 31 | self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size) 32 | self.mean = np.array([104.00699, 116.66877, 122.67892]) 33 | self.files = collections.defaultdict(list) 34 | 35 | if not self.test_mode: 36 | for split in ["training", "validation"]: 37 | file_list = recursive_glob( 38 | rootdir=self.root + "images/" + self.split + "/", suffix=".jpg" 39 | ) 40 | self.files[split] = file_list 41 | 42 | def __len__(self): 43 | return len(self.files[self.split]) 44 | 45 | def __getitem__(self, index): 46 | img_path = self.files[self.split][index].rstrip() 47 | lbl_path = img_path[:-4] + "_seg.png" 48 | 49 | img = m.imread(img_path) 50 | img = np.array(img, dtype=np.uint8) 51 | 52 | lbl = m.imread(lbl_path) 53 | lbl = np.array(lbl, dtype=np.int32) 54 | 55 | if self.augmentations is not None: 56 | img, lbl = self.augmentations(img, lbl) 57 | 58 | if self.is_transform: 59 | img, lbl = self.transform(img, lbl) 60 | 61 | return img, lbl 62 | 63 | def transform(self, img, lbl): 64 | img = m.imresize(img, (self.img_size[0], self.img_size[1])) # uint8 with RGB mode 65 | img = img[:, :, ::-1] # RGB -> BGR 66 | img = img.astype(np.float64) 67 | img -= self.mean 68 | if self.img_norm: 69 | # Resize scales images from 0 to 255, thus we need 70 | # to divide by 255.0 71 | img = img.astype(float) / 255.0 72 | # NHWC -> NCHW 73 | img = img.transpose(2, 0, 1) 74 | 75 | lbl = self.encode_segmap(lbl) 76 | classes = np.unique(lbl) 77 | lbl = lbl.astype(float) 78 | lbl = m.imresize(lbl, (self.img_size[0], self.img_size[1]), "nearest", mode="F") 79 | lbl = lbl.astype(int) 80 | assert np.all(classes == np.unique(lbl)) 81 | 82 | img = torch.from_numpy(img).float() 83 | lbl = torch.from_numpy(lbl).long() 84 | return img, lbl 85 | 86 | def encode_segmap(self, mask): 87 | # Refer : http://groups.csail.mit.edu/vision/datasets/ADE20K/code/loadAde20K.m 88 | mask = mask.astype(int) 89 | label_mask = np.zeros((mask.shape[0], mask.shape[1])) 90 | label_mask = (mask[:, :, 0] / 10.0) * 256 + mask[:, :, 1] 91 | return np.array(label_mask, dtype=np.uint8) 92 | 93 | def decode_segmap(self, temp, plot=False): 94 | # TODO:(@meetshah1995) 95 | # Verify that the color mapping is 1-to-1 96 | r = temp.copy() 97 | g = temp.copy() 98 | b = temp.copy() 99 | for l in range(0, self.n_classes): 100 | r[temp == l] = 10 * (l % 10) 101 | g[temp == l] = l 102 | b[temp == l] = 0 103 | 104 | rgb = np.zeros((temp.shape[0], temp.shape[1], 3)) 105 | rgb[:, :, 0] = r / 255.0 106 | rgb[:, :, 1] = g / 255.0 107 | rgb[:, :, 2] = b / 255.0 108 | if plot: 109 | plt.imshow(rgb) 110 | plt.show() 111 | else: 112 | return rgb 113 | 114 | 115 | if __name__ == "__main__": 116 | local_path = "/Users/meet/data/ADE20K_2016_07_26/" 117 | dst = ADE20KLoader(local_path, is_transform=True) 118 | trainloader = data.DataLoader(dst, batch_size=4) 119 | for i, data_samples in enumerate(trainloader): 120 | imgs, labels = data_samples 121 | if i == 0: 122 | img = torchvision.utils.make_grid(imgs).numpy() 123 | img = np.transpose(img, 
(1, 2, 0)) 124 | img = img[:, :, ::-1] 125 | plt.imshow(img) 126 | plt.show() 127 | for j in range(4): 128 | plt.imshow(dst.decode_segmap(labels.numpy()[j])) 129 | plt.show() 130 | -------------------------------------------------------------------------------- /ptsemseg/loader/camvid_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import collections 3 | import torch 4 | import numpy as np 5 | import scipy.misc as m 6 | import matplotlib.pyplot as plt 7 | 8 | from torch.utils import data 9 | from ptsemseg.augmentations import Compose, RandomHorizontallyFlip, RandomRotate 10 | 11 | 12 | class camvidLoader(data.Dataset): 13 | def __init__( 14 | self, 15 | root, 16 | split="train", 17 | is_transform=False, 18 | img_size=None, 19 | augmentations=None, 20 | img_norm=True, 21 | test_mode=False, 22 | ): 23 | self.root = root 24 | self.split = split 25 | self.img_size = [360, 480] 26 | self.is_transform = is_transform 27 | self.augmentations = augmentations 28 | self.img_norm = img_norm 29 | self.test_mode = test_mode 30 | self.mean = np.array([104.00699, 116.66877, 122.67892]) 31 | self.n_classes = 12 32 | self.files = collections.defaultdict(list) 33 | 34 | if not self.test_mode: 35 | for split in ["train", "test", "val"]: 36 | file_list = os.listdir(root + "/" + split) 37 | self.files[split] = file_list 38 | 39 | def __len__(self): 40 | return len(self.files[self.split]) 41 | 42 | def __getitem__(self, index): 43 | img_name = self.files[self.split][index] 44 | img_path = self.root + "/" + self.split + "/" + img_name 45 | lbl_path = self.root + "/" + self.split + "annot/" + img_name 46 | 47 | img = m.imread(img_path) 48 | img = np.array(img, dtype=np.uint8) 49 | 50 | lbl = m.imread(lbl_path) 51 | lbl = np.array(lbl, dtype=np.int8) 52 | 53 | if self.augmentations is not None: 54 | img, lbl = self.augmentations(img, lbl) 55 | 56 | if self.is_transform: 57 | img, lbl = self.transform(img, lbl) 58 | 59 | return img, lbl 60 | 61 | def transform(self, img, lbl): 62 | img = m.imresize(img, (self.img_size[0], self.img_size[1])) # uint8 with RGB mode 63 | img = img[:, :, ::-1] # RGB -> BGR 64 | img = img.astype(np.float64) 65 | img -= self.mean 66 | if self.img_norm: 67 | # Resize scales images from 0 to 255, thus we need 68 | # to divide by 255.0 69 | img = img.astype(float) / 255.0 70 | # NHWC -> NCHW 71 | img = img.transpose(2, 0, 1) 72 | 73 | img = torch.from_numpy(img).float() 74 | lbl = torch.from_numpy(lbl).long() 75 | return img, lbl 76 | 77 | def decode_segmap(self, temp, plot=False): 78 | Sky = [128, 128, 128] 79 | Building = [128, 0, 0] 80 | Pole = [192, 192, 128] 81 | Road = [128, 64, 128] 82 | Pavement = [60, 40, 222] 83 | Tree = [128, 128, 0] 84 | SignSymbol = [192, 128, 128] 85 | Fence = [64, 64, 128] 86 | Car = [64, 0, 128] 87 | Pedestrian = [64, 64, 0] 88 | Bicyclist = [0, 128, 192] 89 | Unlabelled = [0, 0, 0] 90 | 91 | label_colours = np.array( 92 | [ 93 | Sky, 94 | Building, 95 | Pole, 96 | Road, 97 | Pavement, 98 | Tree, 99 | SignSymbol, 100 | Fence, 101 | Car, 102 | Pedestrian, 103 | Bicyclist, 104 | Unlabelled, 105 | ] 106 | ) 107 | r = temp.copy() 108 | g = temp.copy() 109 | b = temp.copy() 110 | for l in range(0, self.n_classes): 111 | r[temp == l] = label_colours[l, 0] 112 | g[temp == l] = label_colours[l, 1] 113 | b[temp == l] = label_colours[l, 2] 114 | 115 | rgb = np.zeros((temp.shape[0], temp.shape[1], 3)) 116 | rgb[:, :, 0] = r / 255.0 117 | rgb[:, :, 1] = g / 255.0 118 | rgb[:, :, 2] = b / 255.0 119 | return rgb 
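# A sketch (not part of the original file): the per-class loop in decode_segmap
# above can be collapsed into a single NumPy fancy-indexing lookup, assuming
# `temp` contains only valid class indices in [0, n_classes):
#
#     rgb = label_colours[temp] / 255.0   # (H, W, 3) float RGB image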
120 | 121 | 122 | if __name__ == "__main__": 123 | local_path = "/home/meetshah1995/datasets/segnet/CamVid" 124 | augmentations = Compose([RandomRotate(10), RandomHorizontallyFlip()]) 125 | 126 | dst = camvidLoader(local_path, is_transform=True, augmentations=augmentations) 127 | bs = 4 128 | trainloader = data.DataLoader(dst, batch_size=bs) 129 | for i, data_samples in enumerate(trainloader): 130 | imgs, labels = data_samples 131 | imgs = imgs.numpy()[:, ::-1, :, :] 132 | imgs = np.transpose(imgs, [0, 2, 3, 1]) 133 | f, axarr = plt.subplots(bs, 2) 134 | for j in range(bs): 135 | axarr[j][0].imshow(imgs[j]) 136 | axarr[j][1].imshow(dst.decode_segmap(labels.numpy()[j])) 137 | plt.show() 138 | a = input() 139 | if a == "ex": 140 | break 141 | else: 142 | plt.close() 143 | -------------------------------------------------------------------------------- /ptsemseg/loader/cityscapes_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import scipy.misc as m 5 | 6 | from torch.utils import data 7 | 8 | from ptsemseg.utils import recursive_glob 9 | from ptsemseg.augmentations import Compose, RandomHorizontallyFlip, RandomRotate, Scale 10 | 11 | 12 | class cityscapesLoader(data.Dataset): 13 | """cityscapesLoader 14 | 15 | https://www.cityscapes-dataset.com 16 | 17 | Data is derived from CityScapes, and can be downloaded from here: 18 | https://www.cityscapes-dataset.com/downloads/ 19 | 20 | Many Thanks to @fvisin for the loader repo: 21 | https://github.com/fvisin/dataset_loaders/blob/master/dataset_loaders/images/cityscapes.py 22 | """ 23 | 24 | colors = [ # [ 0, 0, 0], 25 | [128, 64, 128], 26 | [244, 35, 232], 27 | [70, 70, 70], 28 | [102, 102, 156], 29 | [190, 153, 153], 30 | [153, 153, 153], 31 | [250, 170, 30], 32 | [220, 220, 0], 33 | [107, 142, 35], 34 | [152, 251, 152], 35 | [0, 130, 180], 36 | [220, 20, 60], 37 | [255, 0, 0], 38 | [0, 0, 142], 39 | [0, 0, 70], 40 | [0, 60, 100], 41 | [0, 80, 100], 42 | [0, 0, 230], 43 | [119, 11, 32], 44 | ] 45 | 46 | label_colours = dict(zip(range(19), colors)) 47 | 48 | mean_rgb = { 49 | "pascal": [103.939, 116.779, 123.68], 50 | "cityscapes": [0.0, 0.0, 0.0], 51 | } # pascal mean for PSPNet and ICNet pre-trained model 52 | 53 | def __init__( 54 | self, 55 | root, 56 | split="train", 57 | is_transform=False, 58 | img_size=(512, 1024), 59 | augmentations=None, 60 | img_norm=True, 61 | version="cityscapes", 62 | test_mode=False, 63 | ): 64 | """__init__ 65 | 66 | :param root: 67 | :param split: 68 | :param is_transform: 69 | :param img_size: 70 | :param augmentations 71 | """ 72 | self.root = root 73 | self.split = split 74 | self.is_transform = is_transform 75 | self.augmentations = augmentations 76 | self.img_norm = img_norm 77 | self.n_classes = 19 78 | self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size) 79 | self.mean = np.array(self.mean_rgb[version]) 80 | self.files = {} 81 | 82 | self.images_base = os.path.join(self.root, "leftImg8bit", self.split) 83 | self.annotations_base = os.path.join(self.root, "gtFine", self.split) 84 | 85 | self.files[split] = recursive_glob(rootdir=self.images_base, suffix=".png") 86 | 87 | self.void_classes = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1] 88 | self.valid_classes = [ 89 | 7, 90 | 8, 91 | 11, 92 | 12, 93 | 13, 94 | 17, 95 | 19, 96 | 20, 97 | 21, 98 | 22, 99 | 23, 100 | 24, 101 | 25, 102 | 26, 103 | 27, 104 | 28, 105 | 31, 106 | 32, 107 | 33, 108 | ] 109 | self.class_names = 
[ 110 | "unlabelled", 111 | "road", 112 | "sidewalk", 113 | "building", 114 | "wall", 115 | "fence", 116 | "pole", 117 | "traffic_light", 118 | "traffic_sign", 119 | "vegetation", 120 | "terrain", 121 | "sky", 122 | "person", 123 | "rider", 124 | "car", 125 | "truck", 126 | "bus", 127 | "train", 128 | "motorcycle", 129 | "bicycle", 130 | ] 131 | 132 | self.ignore_index = 250 133 | self.class_map = dict(zip(self.valid_classes, range(19))) 134 | 135 | if not self.files[split]: 136 | raise Exception("No files for split=[%s] found in %s" % (split, self.images_base)) 137 | 138 | print("Found %d %s images" % (len(self.files[split]), split)) 139 | 140 | def __len__(self): 141 | """__len__""" 142 | return len(self.files[self.split]) 143 | 144 | def __getitem__(self, index): 145 | """__getitem__ 146 | 147 | :param index: 148 | """ 149 | img_path = self.files[self.split][index].rstrip() 150 | lbl_path = os.path.join( 151 | self.annotations_base, 152 | img_path.split(os.sep)[-2], 153 | os.path.basename(img_path)[:-15] + "gtFine_labelIds.png", 154 | ) 155 | 156 | img = m.imread(img_path) 157 | img = np.array(img, dtype=np.uint8) 158 | 159 | lbl = m.imread(lbl_path) 160 | lbl = self.encode_segmap(np.array(lbl, dtype=np.uint8)) 161 | 162 | if self.augmentations is not None: 163 | img, lbl = self.augmentations(img, lbl) 164 | 165 | if self.is_transform: 166 | img, lbl = self.transform(img, lbl) 167 | 168 | return img, lbl 169 | 170 | def transform(self, img, lbl): 171 | """transform 172 | 173 | :param img: 174 | :param lbl: 175 | """ 176 | img = m.imresize(img, (self.img_size[0], self.img_size[1])) # uint8 with RGB mode 177 | img = img[:, :, ::-1] # RGB -> BGR 178 | img = img.astype(np.float64) 179 | img -= self.mean 180 | if self.img_norm: 181 | # Resize scales images from 0 to 255, thus we need 182 | # to divide by 255.0 183 | img = img.astype(float) / 255.0 184 | # NHWC -> NCHW 185 | img = img.transpose(2, 0, 1) 186 | 187 | classes = np.unique(lbl) 188 | lbl = lbl.astype(float) 189 | lbl = m.imresize(lbl, (self.img_size[0], self.img_size[1]), "nearest", mode="F") 190 | lbl = lbl.astype(int) 191 | 192 | if not np.all(classes == np.unique(lbl)): 193 | print("WARN: resizing labels yielded fewer classes") 194 | 195 | if not np.all(np.unique(lbl[lbl != self.ignore_index]) < self.n_classes): 196 | print("after det", classes, np.unique(lbl)) 197 | raise ValueError("Segmentation map contained invalid class values") 198 | 199 | img = torch.from_numpy(img).float() 200 | lbl = torch.from_numpy(lbl).long() 201 | 202 | return img, lbl 203 | 204 | def decode_segmap(self, temp): 205 | r = temp.copy() 206 | g = temp.copy() 207 | b = temp.copy() 208 | for l in range(0, self.n_classes): 209 | r[temp == l] = self.label_colours[l][0] 210 | g[temp == l] = self.label_colours[l][1] 211 | b[temp == l] = self.label_colours[l][2] 212 | 213 | rgb = np.zeros((temp.shape[0], temp.shape[1], 3)) 214 | rgb[:, :, 0] = r / 255.0 215 | rgb[:, :, 1] = g / 255.0 216 | rgb[:, :, 2] = b / 255.0 217 | return rgb 218 | 219 | def encode_segmap(self, mask): 220 | # Put all void classes to zero 221 | for _voidc in self.void_classes: 222 | mask[mask == _voidc] = self.ignore_index 223 | for _validc in self.valid_classes: 224 | mask[mask == _validc] = self.class_map[_validc] 225 | return mask 226 | 227 | 228 | if __name__ == "__main__": 229 | import matplotlib.pyplot as plt 230 | 231 | augmentations = Compose([Scale(2048), RandomRotate(10), RandomHorizontallyFlip(0.5)]) 232 | 233 | local_path = "/datasets01/cityscapes/112817/" 234 | dst = 
cityscapesLoader(local_path, is_transform=True, augmentations=augmentations) 235 | bs = 4 236 | trainloader = data.DataLoader(dst, batch_size=bs, num_workers=0) 237 | for i, data_samples in enumerate(trainloader): 238 | imgs, labels = data_samples 239 | import pdb 240 | 241 | pdb.set_trace() 242 | imgs = imgs.numpy()[:, ::-1, :, :] 243 | imgs = np.transpose(imgs, [0, 2, 3, 1]) 244 | f, axarr = plt.subplots(bs, 2) 245 | for j in range(bs): 246 | axarr[j][0].imshow(imgs[j]) 247 | axarr[j][1].imshow(dst.decode_segmap(labels.numpy()[j])) 248 | plt.show() 249 | a = input() 250 | if a == "ex": 251 | break 252 | else: 253 | plt.close() 254 | -------------------------------------------------------------------------------- /ptsemseg/loader/mapillary_vistas_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import numpy as np 5 | 6 | from torch.utils import data 7 | from PIL import Image 8 | 9 | from ptsemseg.utils import recursive_glob 10 | from ptsemseg.augmentations import Compose, RandomHorizontallyFlip, RandomRotate 11 | 12 | 13 | class mapillaryVistasLoader(data.Dataset): 14 | def __init__( 15 | self, 16 | root, 17 | split="training", 18 | img_size=(640, 1280), 19 | is_transform=True, 20 | augmentations=None, 21 | test_mode=False, 22 | ): 23 | self.root = root 24 | self.split = split 25 | self.is_transform = is_transform 26 | self.augmentations = augmentations 27 | self.n_classes = 65 28 | 29 | self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size) 30 | self.mean = np.array([80.5423, 91.3162, 81.4312]) 31 | self.files = {} 32 | 33 | self.images_base = os.path.join(self.root, self.split, "images") 34 | self.annotations_base = os.path.join(self.root, self.split, "labels") 35 | 36 | self.files[split] = recursive_glob(rootdir=self.images_base, suffix=".jpg") 37 | 38 | self.class_ids, self.class_names, self.class_colors = self.parse_config() 39 | 40 | self.ignore_id = 250 41 | 42 | if not self.files[split]: 43 | raise Exception("No files for split=[%s] found in %s" % (split, self.images_base)) 44 | 45 | print("Found %d %s images" % (len(self.files[split]), split)) 46 | 47 | def parse_config(self): 48 | with open(os.path.join(self.root, "config.json")) as config_file: 49 | config = json.load(config_file) 50 | 51 | labels = config["labels"] 52 | 53 | class_names = [] 54 | class_ids = [] 55 | class_colors = [] 56 | print("There are {} labels in the config file".format(len(labels))) 57 | for label_id, label in enumerate(labels): 58 | class_names.append(label["readable"]) 59 | class_ids.append(label_id) 60 | class_colors.append(label["color"]) 61 | 62 | return class_names, class_ids, class_colors 63 | 64 | def __len__(self): 65 | """__len__""" 66 | return len(self.files[self.split]) 67 | 68 | def __getitem__(self, index): 69 | """__getitem__ 70 | :param index: 71 | """ 72 | img_path = self.files[self.split][index].rstrip() 73 | lbl_path = os.path.join( 74 | self.annotations_base, os.path.basename(img_path).replace(".jpg", ".png") 75 | ) 76 | 77 | img = Image.open(img_path) 78 | lbl = Image.open(lbl_path) 79 | 80 | if self.augmentations is not None: 81 | img, lbl = self.augmentations(img, lbl) 82 | 83 | if self.is_transform: 84 | img, lbl = self.transform(img, lbl) 85 | 86 | return img, lbl 87 | 88 | def transform(self, img, lbl): 89 | if self.img_size == ("same", "same"): 90 | pass 91 | else: 92 | img = img.resize( 93 | (self.img_size[0], self.img_size[1]), resample=Image.LANCZOS 94 | ) 
# uint8 with RGB mode 95 | lbl = lbl.resize((self.img_size[0], self.img_size[1])) 96 | img = np.array(img).astype(np.float64) / 255.0 97 | img = torch.from_numpy(img.transpose(2, 0, 1)).float() # From HWC to CHW 98 | lbl = torch.from_numpy(np.array(lbl)).long() 99 | lbl[lbl == 65] = self.ignore_id 100 | return img, lbl 101 | 102 | def decode_segmap(self, temp): 103 | r = temp.copy() 104 | g = temp.copy() 105 | b = temp.copy() 106 | for l in range(0, self.n_classes): 107 | r[temp == l] = self.class_colors[l][0] 108 | g[temp == l] = self.class_colors[l][1] 109 | b[temp == l] = self.class_colors[l][2] 110 | 111 | rgb = np.zeros((temp.shape[0], temp.shape[1], 3)) 112 | rgb[:, :, 0] = r / 255.0 113 | rgb[:, :, 1] = g / 255.0 114 | rgb[:, :, 2] = b / 255.0 115 | return rgb 116 | 117 | 118 | if __name__ == "__main__": 119 | augment = Compose([RandomHorizontallyFlip(), RandomRotate(6)]) 120 | 121 | local_path = "/private/home/meetshah/datasets/seg/vistas/" 122 | dst = mapillaryVistasLoader( 123 | local_path, img_size=(512, 1024), is_transform=True, augmentations=augment 124 | ) 125 | bs = 8 126 | trainloader = data.DataLoader(dst, batch_size=bs, num_workers=4, shuffle=True) 127 | for i, data_samples in enumerate(trainloader): 128 | x = dst.decode_segmap(data_samples[1][0].numpy()) 129 | print("batch :", i) 130 | -------------------------------------------------------------------------------- /ptsemseg/loader/mit_sceneparsing_benchmark_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import scipy.misc as m 5 | 6 | from torch.utils import data 7 | 8 | from ptsemseg.utils import recursive_glob 9 | 10 | 11 | class MITSceneParsingBenchmarkLoader(data.Dataset): 12 | """MITSceneParsingBenchmarkLoader 13 | 14 | http://sceneparsing.csail.mit.edu/ 15 | 16 | Data is derived from ADE20k, and can be downloaded from here: 17 | http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip 18 | 19 | NOTE: this loader is not designed to work with the original ADE20k dataset; 20 | for that you will need the ADE20kLoader 21 | 22 | This class can also be extended to load data for places challenge: 23 | https://github.com/CSAILVision/placeschallenge/tree/master/sceneparsing 24 | 25 | """ 26 | 27 | def __init__( 28 | self, 29 | root, 30 | split="training", 31 | is_transform=False, 32 | img_size=512, 33 | augmentations=None, 34 | img_norm=True, 35 | test_mode=False, 36 | ): 37 | """__init__ 38 | 39 | :param root: 40 | :param split: 41 | :param is_transform: 42 | :param img_size: 43 | """ 44 | self.root = root 45 | self.split = split 46 | self.is_transform = is_transform 47 | self.augmentations = augmentations 48 | self.img_norm = img_norm 49 | self.n_classes = 151 # 0 is reserved for "other" 50 | self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size) 51 | self.mean = np.array([104.00699, 116.66877, 122.67892]) 52 | self.files = {} 53 | 54 | self.images_base = os.path.join(self.root, "images", self.split) 55 | self.annotations_base = os.path.join(self.root, "annotations", self.split) 56 | 57 | self.files[split] = recursive_glob(rootdir=self.images_base, suffix=".jpg") 58 | 59 | if not self.files[split]: 60 | raise Exception("No files for split=[%s] found in %s" % (split, self.images_base)) 61 | 62 | print("Found %d %s images" % (len(self.files[split]), split)) 63 | 64 | def __len__(self): 65 | """__len__""" 66 | return len(self.files[self.split]) 67 | 68 | def __getitem__(self, index): 
69 | """__getitem__ 70 | 71 | :param index: 72 | """ 73 | img_path = self.files[self.split][index].rstrip() 74 | lbl_path = os.path.join(self.annotations_base, os.path.basename(img_path)[:-4] + ".png") 75 | 76 | img = m.imread(img_path, mode="RGB") 77 | img = np.array(img, dtype=np.uint8) 78 | 79 | lbl = m.imread(lbl_path) 80 | lbl = np.array(lbl, dtype=np.uint8) 81 | 82 | if self.augmentations is not None: 83 | img, lbl = self.augmentations(img, lbl) 84 | 85 | if self.is_transform: 86 | img, lbl = self.transform(img, lbl) 87 | 88 | return img, lbl 89 | 90 | def transform(self, img, lbl): 91 | """transform 92 | 93 | :param img: 94 | :param lbl: 95 | """ 96 | if self.img_size == ("same", "same"): 97 | pass 98 | else: 99 | img = m.imresize(img, (self.img_size[0], self.img_size[1])) # uint8 with RGB mode 100 | img = img[:, :, ::-1] # RGB -> BGR 101 | img = img.astype(np.float64) 102 | img -= self.mean 103 | if self.img_norm: 104 | # Resize scales images from 0 to 255, thus we need 105 | # to divide by 255.0 106 | img = img.astype(float) / 255.0 107 | # NHWC -> NCHW 108 | img = img.transpose(2, 0, 1) 109 | 110 | classes = np.unique(lbl) 111 | lbl = lbl.astype(float) 112 | if self.img_size == ("same", "same"): 113 | pass 114 | else: 115 | lbl = m.imresize(lbl, (self.img_size[0], self.img_size[1]), "nearest", mode="F") 116 | lbl = lbl.astype(int) 117 | 118 | if not np.all(classes == np.unique(lbl)): 119 | print("WARN: resizing labels yielded fewer classes") 120 | 121 | if not np.all(np.unique(lbl) < self.n_classes): 122 | raise ValueError("Segmentation map contained invalid class values") 123 | 124 | img = torch.from_numpy(img).float() 125 | lbl = torch.from_numpy(lbl).long() 126 | 127 | return img, lbl 128 | -------------------------------------------------------------------------------- /ptsemseg/loader/nyuv2_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import collections 3 | import torch 4 | import numpy as np 5 | import scipy.misc as m 6 | 7 | from torch.utils import data 8 | 9 | from ptsemseg.utils import recursive_glob 10 | from ptsemseg.augmentations import Compose, RandomHorizontallyFlip, RandomRotate, Scale 11 | 12 | 13 | class NYUv2Loader(data.Dataset): 14 | """ 15 | NYUv2 loader 16 | Download From (only 13 classes): 17 | test source: http://www.doc.ic.ac.uk/~ahanda/nyu_test_rgb.tgz 18 | train source: http://www.doc.ic.ac.uk/~ahanda/nyu_train_rgb.tgz 19 | test_labels source: 20 | https://github.com/ankurhanda/nyuv2-meta-data/raw/master/test_labels_13/nyuv2_test_class13.tgz 21 | train_labels source: 22 | https://github.com/ankurhanda/nyuv2-meta-data/raw/master/train_labels_13/nyuv2_train_class13.tgz 23 | 24 | """ 25 | 26 | def __init__( 27 | self, 28 | root, 29 | split="training", 30 | is_transform=False, 31 | img_size=(480, 640), 32 | augmentations=None, 33 | img_norm=True, 34 | test_mode=False, 35 | ): 36 | self.root = root 37 | self.is_transform = is_transform 38 | self.n_classes = 14 39 | self.augmentations = augmentations 40 | self.img_norm = img_norm 41 | self.test_mode = test_mode 42 | self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size) 43 | self.mean = np.array([104.00699, 116.66877, 122.67892]) 44 | self.files = collections.defaultdict(list) 45 | self.cmap = self.color_map(normalized=False) 46 | 47 | split_map = {"training": "train", "val": "test"} 48 | self.split = split_map[split] 49 | 50 | for split in ["train", "test"]: 51 | file_list = recursive_glob(rootdir=self.root + split + 
"/", suffix="png") 52 | self.files[split] = file_list 53 | 54 | def __len__(self): 55 | return len(self.files[self.split]) 56 | 57 | def __getitem__(self, index): 58 | img_path = self.files[self.split][index].rstrip() 59 | img_number = img_path.split("_")[-1][:4] 60 | lbl_path = os.path.join( 61 | self.root, self.split + "_annot", "new_nyu_class13_" + img_number + ".png" 62 | ) 63 | 64 | img = m.imread(img_path) 65 | img = np.array(img, dtype=np.uint8) 66 | 67 | lbl = m.imread(lbl_path) 68 | lbl = np.array(lbl, dtype=np.uint8) 69 | 70 | if not (len(img.shape) == 3 and len(lbl.shape) == 2): 71 | return self.__getitem__(np.random.randint(0, self.__len__())) 72 | 73 | if self.augmentations is not None: 74 | img, lbl = self.augmentations(img, lbl) 75 | 76 | if self.is_transform: 77 | img, lbl = self.transform(img, lbl) 78 | 79 | return img, lbl 80 | 81 | def transform(self, img, lbl): 82 | img = m.imresize(img, (self.img_size[0], self.img_size[1])) # uint8 with RGB mode 83 | img = img[:, :, ::-1] # RGB -> BGR 84 | img = img.astype(np.float64) 85 | img -= self.mean 86 | if self.img_norm: 87 | # Resize scales images from 0 to 255, thus we need 88 | # to divide by 255.0 89 | img = img.astype(float) / 255.0 90 | # NHWC -> NCHW 91 | img = img.transpose(2, 0, 1) 92 | 93 | classes = np.unique(lbl) 94 | lbl = lbl.astype(float) 95 | lbl = m.imresize(lbl, (self.img_size[0], self.img_size[1]), "nearest", mode="F") 96 | lbl = lbl.astype(int) 97 | assert np.all(classes == np.unique(lbl)) 98 | 99 | img = torch.from_numpy(img).float() 100 | lbl = torch.from_numpy(lbl).long() 101 | return img, lbl 102 | 103 | def color_map(self, N=256, normalized=False): 104 | """ 105 | Return Color Map in PASCAL VOC format 106 | """ 107 | 108 | def bitget(byteval, idx): 109 | return (byteval & (1 << idx)) != 0 110 | 111 | dtype = "float32" if normalized else "uint8" 112 | cmap = np.zeros((N, 3), dtype=dtype) 113 | for i in range(N): 114 | r = g = b = 0 115 | c = i 116 | for j in range(8): 117 | r = r | (bitget(c, 0) << 7 - j) 118 | g = g | (bitget(c, 1) << 7 - j) 119 | b = b | (bitget(c, 2) << 7 - j) 120 | c = c >> 3 121 | 122 | cmap[i] = np.array([r, g, b]) 123 | 124 | cmap = cmap / 255.0 if normalized else cmap 125 | return cmap 126 | 127 | def decode_segmap(self, temp): 128 | r = temp.copy() 129 | g = temp.copy() 130 | b = temp.copy() 131 | for l in range(0, self.n_classes): 132 | r[temp == l] = self.cmap[l, 0] 133 | g[temp == l] = self.cmap[l, 1] 134 | b[temp == l] = self.cmap[l, 2] 135 | 136 | rgb = np.zeros((temp.shape[0], temp.shape[1], 3)) 137 | rgb[:, :, 0] = r / 255.0 138 | rgb[:, :, 1] = g / 255.0 139 | rgb[:, :, 2] = b / 255.0 140 | return rgb 141 | 142 | 143 | if __name__ == "__main__": 144 | import matplotlib.pyplot as plt 145 | 146 | augmentations = Compose([Scale(512), RandomRotate(10), RandomHorizontallyFlip()]) 147 | 148 | local_path = "/home/meet/datasets/NYUv2/" 149 | dst = NYUv2Loader(local_path, is_transform=True, augmentations=augmentations) 150 | bs = 4 151 | trainloader = data.DataLoader(dst, batch_size=bs, num_workers=0) 152 | for i, datas in enumerate(trainloader): 153 | imgs, labels = datas 154 | imgs = imgs.numpy()[:, ::-1, :, :] 155 | imgs = np.transpose(imgs, [0, 2, 3, 1]) 156 | f, axarr = plt.subplots(bs, 2) 157 | for j in range(bs): 158 | axarr[j][0].imshow(imgs[j]) 159 | axarr[j][1].imshow(dst.decode_segmap(labels.numpy()[j])) 160 | plt.show() 161 | a = input() 162 | if a == "ex": 163 | break 164 | else: 165 | plt.close() 166 | 
-------------------------------------------------------------------------------- /ptsemseg/loader/pascal_voc_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join as pjoin 3 | import collections 4 | import json 5 | import torch 6 | import numpy as np 7 | import scipy.misc as m 8 | import scipy.io as io 9 | import matplotlib.pyplot as plt 10 | import glob 11 | 12 | from PIL import Image 13 | from tqdm import tqdm 14 | from torch.utils import data 15 | from torchvision import transforms 16 | 17 | 18 | class pascalVOCLoader(data.Dataset): 19 | """Data loader for the Pascal VOC semantic segmentation dataset. 20 | 21 | Annotations from both the original VOC data (which consist of RGB images 22 | in which colours map to specific classes) and the SBD (Berkely) dataset 23 | (where annotations are stored as .mat files) are converted into a common 24 | `label_mask` format. Under this format, each mask is an (M,N) array of 25 | integer values from 0 to 21, where 0 represents the background class. 26 | 27 | The label masks are stored in a new folder, called `pre_encoded`, which 28 | is added as a subdirectory of the `SegmentationClass` folder in the 29 | original Pascal VOC data layout. 30 | 31 | A total of five data splits are provided for working with the VOC data: 32 | train: The original VOC 2012 training data - 1464 images 33 | val: The original VOC 2012 validation data - 1449 images 34 | trainval: The combination of `train` and `val` - 2913 images 35 | train_aug: The unique images present in both the train split and 36 | training images from SBD: - 8829 images (the unique members 37 | of the result of combining lists of length 1464 and 8498) 38 | train_aug_val: The original VOC 2012 validation data minus the images 39 | present in `train_aug` (This is done with the same logic as 40 | the validation set used in FCN PAMI paper, but with VOC 2012 41 | rather than VOC 2011) - 904 images 42 | """ 43 | 44 | def __init__( 45 | self, 46 | root, 47 | sbd_path=None, 48 | split="train_aug", 49 | is_transform=False, 50 | img_size=512, 51 | augmentations=None, 52 | img_norm=True, 53 | test_mode=False, 54 | ): 55 | self.root = root 56 | self.sbd_path = sbd_path 57 | self.split = split 58 | self.is_transform = is_transform 59 | self.augmentations = augmentations 60 | self.img_norm = img_norm 61 | self.test_mode = test_mode 62 | self.n_classes = 21 63 | self.mean = np.array([104.00699, 116.66877, 122.67892]) 64 | self.files = collections.defaultdict(list) 65 | self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size) 66 | 67 | if not self.test_mode: 68 | for split in ["train", "val", "trainval"]: 69 | path = pjoin(self.root, "ImageSets/Segmentation", split + ".txt") 70 | file_list = tuple(open(path, "r")) 71 | file_list = [id_.rstrip() for id_ in file_list] 72 | self.files[split] = file_list 73 | self.setup_annotations() 74 | 75 | self.tf = transforms.Compose( 76 | [ 77 | transforms.ToTensor(), 78 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 79 | ] 80 | ) 81 | 82 | def __len__(self): 83 | return len(self.files[self.split]) 84 | 85 | def __getitem__(self, index): 86 | im_name = self.files[self.split][index] 87 | im_path = pjoin(self.root, "JPEGImages", im_name + ".jpg") 88 | lbl_path = pjoin(self.root, "SegmentationClass/pre_encoded", im_name + ".png") 89 | im = Image.open(im_path) 90 | lbl = Image.open(lbl_path) 91 | if self.augmentations is not None: 92 | im, lbl = 
self.augmentations(im, lbl) 93 | if self.is_transform: 94 | im, lbl = self.transform(im, lbl) 95 | return im, lbl 96 | 97 | def transform(self, img, lbl): 98 | if self.img_size == ("same", "same"): 99 | pass 100 | else: 101 | img = img.resize((self.img_size[0], self.img_size[1])) # uint8 with RGB mode 102 | lbl = lbl.resize((self.img_size[0], self.img_size[1])) 103 | img = self.tf(img) 104 | lbl = torch.from_numpy(np.array(lbl)).long() 105 | lbl[lbl == 255] = 0 106 | return img, lbl 107 | 108 | def get_pascal_labels(self): 109 | """Load the mapping that associates pascal classes with label colors 110 | 111 | Returns: 112 | np.ndarray with dimensions (21, 3) 113 | """ 114 | return np.asarray( 115 | [ 116 | [0, 0, 0], 117 | [128, 0, 0], 118 | [0, 128, 0], 119 | [128, 128, 0], 120 | [0, 0, 128], 121 | [128, 0, 128], 122 | [0, 128, 128], 123 | [128, 128, 128], 124 | [64, 0, 0], 125 | [192, 0, 0], 126 | [64, 128, 0], 127 | [192, 128, 0], 128 | [64, 0, 128], 129 | [192, 0, 128], 130 | [64, 128, 128], 131 | [192, 128, 128], 132 | [0, 64, 0], 133 | [128, 64, 0], 134 | [0, 192, 0], 135 | [128, 192, 0], 136 | [0, 64, 128], 137 | ] 138 | ) 139 | 140 | def encode_segmap(self, mask): 141 | """Encode segmentation label images as pascal classes 142 | 143 | Args: 144 | mask (np.ndarray): raw segmentation label image of dimension 145 | (M, N, 3), in which the Pascal classes are encoded as colours. 146 | 147 | Returns: 148 | (np.ndarray): class map with dimensions (M,N), where the value at 149 | a given location is the integer denoting the class index. 150 | """ 151 | mask = mask.astype(int) 152 | label_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int16) 153 | for ii, label in enumerate(self.get_pascal_labels()): 154 | label_mask[np.where(np.all(mask == label, axis=-1))[:2]] = ii 155 | label_mask = label_mask.astype(int) 156 | return label_mask 157 | 158 | def decode_segmap(self, label_mask, plot=False): 159 | """Decode segmentation class labels into a color image 160 | 161 | Args: 162 | label_mask (np.ndarray): an (M,N) array of integer values denoting 163 | the class label at each spatial location. 164 | plot (bool, optional): whether to show the resulting color image 165 | in a figure. 166 | 167 | Returns: 168 | (np.ndarray, optional): the resulting decoded color image. 169 | """ 170 | label_colours = self.get_pascal_labels() 171 | r = label_mask.copy() 172 | g = label_mask.copy() 173 | b = label_mask.copy() 174 | for ll in range(0, self.n_classes): 175 | r[label_mask == ll] = label_colours[ll, 0] 176 | g[label_mask == ll] = label_colours[ll, 1] 177 | b[label_mask == ll] = label_colours[ll, 2] 178 | rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3)) 179 | rgb[:, :, 0] = r / 255.0 180 | rgb[:, :, 1] = g / 255.0 181 | rgb[:, :, 2] = b / 255.0 182 | if plot: 183 | plt.imshow(rgb) 184 | plt.show() 185 | else: 186 | return rgb 187 | 188 | def setup_annotations(self): 189 | """Sets up Berkley annotations by adding image indices to the 190 | `train_aug` split and pre-encode all segmentation labels into the 191 | common label_mask format (if this has not already been done). 
This 192 | function also defines the `train_aug` and `train_aug_val` data splits 193 | according to the description in the class docstring 194 | """ 195 | sbd_path = self.sbd_path 196 | target_path = pjoin(self.root, "SegmentationClass/pre_encoded") 197 | if not os.path.exists(target_path): 198 | os.makedirs(target_path) 199 | path = pjoin(sbd_path, "dataset/train.txt") 200 | sbd_train_list = tuple(open(path, "r")) 201 | sbd_train_list = [id_.rstrip() for id_ in sbd_train_list] 202 | train_aug = self.files["train"] + sbd_train_list 203 | 204 | # keep unique elements (stable) 205 | train_aug = [train_aug[i] for i in sorted(np.unique(train_aug, return_index=True)[1])] 206 | self.files["train_aug"] = train_aug 207 | set_diff = set(self.files["val"]) - set(train_aug) # remove overlap 208 | self.files["train_aug_val"] = list(set_diff) 209 | 210 | pre_encoded = glob.glob(pjoin(target_path, "*.png")) 211 | expected = np.unique(self.files["train_aug"] + self.files["val"]).size 212 | 213 | if len(pre_encoded) != expected: 214 | print("Pre-encoding segmentation masks...") 215 | for ii in tqdm(sbd_train_list): 216 | lbl_path = pjoin(sbd_path, "dataset/cls", ii + ".mat") 217 | data = io.loadmat(lbl_path) 218 | lbl = data["GTcls"][0]["Segmentation"][0].astype(np.int32) 219 | lbl = m.toimage(lbl, high=lbl.max(), low=lbl.min()) 220 | m.imsave(pjoin(target_path, ii + ".png"), lbl) 221 | 222 | for ii in tqdm(self.files["trainval"]): 223 | fname = ii + ".png" 224 | lbl_path = pjoin(self.root, "SegmentationClass", fname) 225 | lbl = self.encode_segmap(m.imread(lbl_path)) 226 | lbl = m.toimage(lbl, high=lbl.max(), low=lbl.min()) 227 | m.imsave(pjoin(target_path, fname), lbl) 228 | 229 | assert expected == 9733, "unexpected dataset sizes" 230 | 231 | 232 | # Leave code for debugging purposes 233 | # import ptsemseg.augmentations as aug 234 | # if __name__ == '__main__': 235 | # # local_path = '/home/meetshah1995/datasets/VOCdevkit/VOC2012/' 236 | # bs = 4 237 | # augs = aug.Compose([aug.RandomRotate(10), aug.RandomHorizontallyFlip()]) 238 | # dst = pascalVOCLoader(root=local_path, is_transform=True, augmentations=augs) 239 | # trainloader = data.DataLoader(dst, batch_size=bs) 240 | # for i, data in enumerate(trainloader): 241 | # imgs, labels = data 242 | # imgs = imgs.numpy()[:, ::-1, :, :] 243 | # imgs = np.transpose(imgs, [0,2,3,1]) 244 | # f, axarr = plt.subplots(bs, 2) 245 | # for j in range(bs): 246 | # axarr[j][0].imshow(imgs[j]) 247 | # axarr[j][1].imshow(dst.decode_segmap(labels.numpy()[j])) 248 | # plt.show() 249 | # a = raw_input() 250 | # if a == 'ex': 251 | # break 252 | # else: 253 | # plt.close() 254 | -------------------------------------------------------------------------------- /ptsemseg/loader/sunrgbd_loader.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import torch 3 | import numpy as np 4 | import scipy.misc as m 5 | 6 | from torch.utils import data 7 | 8 | from ptsemseg.utils import recursive_glob 9 | from ptsemseg.augmentations import Compose, RandomHorizontallyFlip, RandomRotate, Scale 10 | 11 | 12 | class SUNRGBDLoader(data.Dataset): 13 | """SUNRGBD loader 14 | 15 | Download From: 16 | http://www.doc.ic.ac.uk/~ahanda/SUNRGBD-test_images.tgz 17 | test source: http://www.doc.ic.ac.uk/~ahanda/SUNRGBD-test_images.tgz 18 | train source: http://www.doc.ic.ac.uk/~ahanda/SUNRGBD-train_images.tgz 19 | 20 | first 5050 in this is test, later 5051 is train 21 | test and train labels source: 22 | 
https://github.com/ankurhanda/sunrgbd-meta-data/raw/master/sunrgbd_train_test_labels.tar.gz 23 | """ 24 | 25 | def __init__( 26 | self, 27 | root, 28 | split="training", 29 | is_transform=False, 30 | img_size=(480, 640), 31 | augmentations=None, 32 | img_norm=True, 33 | test_mode=False, 34 | ): 35 | self.root = root 36 | self.is_transform = is_transform 37 | self.n_classes = 38 38 | self.augmentations = augmentations 39 | self.img_norm = img_norm 40 | self.test_mode = test_mode 41 | self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size) 42 | self.mean = np.array([104.00699, 116.66877, 122.67892]) 43 | self.files = collections.defaultdict(list) 44 | self.anno_files = collections.defaultdict(list) 45 | self.cmap = self.color_map(normalized=False) 46 | 47 | split_map = {"training": "train", "val": "test"} 48 | self.split = split_map[split] 49 | 50 | for split in ["train", "test"]: 51 | file_list = sorted(recursive_glob(rootdir=self.root + split + "/", suffix="jpg")) 52 | self.files[split] = file_list 53 | 54 | for split in ["train", "test"]: 55 | file_list = sorted( 56 | recursive_glob(rootdir=self.root + "annotations/" + split + "/", suffix="png") 57 | ) 58 | self.anno_files[split] = file_list 59 | 60 | def __len__(self): 61 | return len(self.files[self.split]) 62 | 63 | def __getitem__(self, index): 64 | img_path = self.files[self.split][index].rstrip() 65 | lbl_path = self.anno_files[self.split][index].rstrip() 66 | # img_number = img_path.split('/')[-1] 67 | # lbl_path = os.path.join(self.root, 'annotations', img_number).replace('jpg', 'png') 68 | 69 | img = m.imread(img_path) 70 | img = np.array(img, dtype=np.uint8) 71 | 72 | lbl = m.imread(lbl_path) 73 | lbl = np.array(lbl, dtype=np.uint8) 74 | 75 | if not (len(img.shape) == 3 and len(lbl.shape) == 2): 76 | return self.__getitem__(np.random.randint(0, self.__len__())) 77 | 78 | if self.augmentations is not None: 79 | img, lbl = self.augmentations(img, lbl) 80 | 81 | if self.is_transform: 82 | img, lbl = self.transform(img, lbl) 83 | 84 | return img, lbl 85 | 86 | def transform(self, img, lbl): 87 | img = m.imresize(img, (self.img_size[0], self.img_size[1])) # uint8 with RGB mode 88 | img = img[:, :, ::-1] # RGB -> BGR 89 | img = img.astype(np.float64) 90 | img -= self.mean 91 | if self.img_norm: 92 | # Resize scales images from 0 to 255, thus we need 93 | # to divide by 255.0 94 | img = img.astype(float) / 255.0 95 | # NHWC -> NCHW 96 | img = img.transpose(2, 0, 1) 97 | 98 | classes = np.unique(lbl) 99 | lbl = lbl.astype(float) 100 | lbl = m.imresize(lbl, (self.img_size[0], self.img_size[1]), "nearest", mode="F") 101 | lbl = lbl.astype(int) 102 | assert np.all(classes == np.unique(lbl)) 103 | 104 | img = torch.from_numpy(img).float() 105 | lbl = torch.from_numpy(lbl).long() 106 | return img, lbl 107 | 108 | def color_map(self, N=256, normalized=False): 109 | """ 110 | Return Color Map in PASCAL VOC format 111 | """ 112 | 113 | def bitget(byteval, idx): 114 | return (byteval & (1 << idx)) != 0 115 | 116 | dtype = "float32" if normalized else "uint8" 117 | cmap = np.zeros((N, 3), dtype=dtype) 118 | for i in range(N): 119 | r = g = b = 0 120 | c = i 121 | for j in range(8): 122 | r = r | (bitget(c, 0) << 7 - j) 123 | g = g | (bitget(c, 1) << 7 - j) 124 | b = b | (bitget(c, 2) << 7 - j) 125 | c = c >> 3 126 | 127 | cmap[i] = np.array([r, g, b]) 128 | 129 | cmap = cmap / 255.0 if normalized else cmap 130 | return cmap 131 | 132 | def decode_segmap(self, temp): 133 | r = temp.copy() 134 | g = temp.copy() 135 | b = 
temp.copy() 136 | for l in range(0, self.n_classes): 137 | r[temp == l] = self.cmap[l, 0] 138 | g[temp == l] = self.cmap[l, 1] 139 | b[temp == l] = self.cmap[l, 2] 140 | 141 | rgb = np.zeros((temp.shape[0], temp.shape[1], 3)) 142 | rgb[:, :, 0] = r / 255.0 143 | rgb[:, :, 1] = g / 255.0 144 | rgb[:, :, 2] = b / 255.0 145 | return rgb 146 | 147 | 148 | if __name__ == "__main__": 149 | import matplotlib.pyplot as plt 150 | 151 | augmentations = Compose([Scale(512), RandomRotate(10), RandomHorizontallyFlip()]) 152 | 153 | local_path = "/home/meet/datasets/SUNRGBD/" 154 | dst = SUNRGBDLoader(local_path, is_transform=True, augmentations=augmentations) 155 | bs = 4 156 | trainloader = data.DataLoader(dst, batch_size=bs, num_workers=0) 157 | for i, data_samples in enumerate(trainloader): 158 | imgs, labels = data_samples 159 | imgs = imgs.numpy()[:, ::-1, :, :] 160 | imgs = np.transpose(imgs, [0, 2, 3, 1]) 161 | f, axarr = plt.subplots(bs, 2) 162 | for j in range(bs): 163 | axarr[j][0].imshow(imgs[j]) 164 | axarr[j][1].imshow(dst.decode_segmap(labels.numpy()[j])) 165 | plt.show() 166 | a = input() 167 | if a == "ex": 168 | break 169 | else: 170 | plt.close() 171 | -------------------------------------------------------------------------------- /ptsemseg/loss/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import functools 3 | 4 | from ptsemseg.loss.loss import ( 5 | cross_entropy2d, 6 | bootstrapped_cross_entropy2d, 7 | multi_scale_cross_entropy2d, 8 | ) 9 | 10 | 11 | logger = logging.getLogger("ptsemseg") 12 | 13 | key2loss = { 14 | "cross_entropy": cross_entropy2d, 15 | "bootstrapped_cross_entropy": bootstrapped_cross_entropy2d, 16 | "multi_scale_cross_entropy": multi_scale_cross_entropy2d, 17 | } 18 | 19 | 20 | def get_loss_function(cfg): 21 | if cfg["training"]["loss"] is None: 22 | logger.info("Using default cross entropy loss") 23 | return cross_entropy2d 24 | 25 | else: 26 | loss_dict = cfg["training"]["loss"] 27 | loss_name = loss_dict["name"] 28 | loss_params = {k: v for k, v in loss_dict.items() if k != "name"} 29 | 30 | if loss_name not in key2loss: 31 | raise NotImplementedError("Loss {} not implemented".format(loss_name)) 32 | 33 | logger.info("Using {} with {} params".format(loss_name, loss_params)) 34 | return functools.partial(key2loss[loss_name], **loss_params) 35 | -------------------------------------------------------------------------------- /ptsemseg/loss/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def cross_entropy2d(input, target, weight=None, size_average=True): 6 | n, c, h, w = input.size() 7 | nt, ht, wt = target.size() 8 | 9 | # Handle inconsistent size between input and target 10 | if h != ht and w != wt: # upsample labels 11 | input = F.interpolate(input, size=(ht, wt), mode="bilinear", align_corners=True) 12 | 13 | input = input.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c) 14 | target = target.view(-1) 15 | loss = F.cross_entropy( 16 | input, target, weight=weight, size_average=size_average, ignore_index=250 17 | ) 18 | return loss 19 | 20 | 21 | def multi_scale_cross_entropy2d(input, target, weight=None, size_average=True, scale_weight=None): 22 | if not isinstance(input, tuple): 23 | return cross_entropy2d(input=input, target=target, weight=weight, size_average=size_average) 24 | 25 | # Auxiliary training for PSPNet [1.0, 0.4] and ICNet [1.0, 0.4, 0.16] 26 | if scale_weight 
is None: # scale_weight: torch tensor type 27 | n_inp = len(input) 28 | scale = 0.4 29 | scale_weight = torch.pow(scale * torch.ones(n_inp), torch.arange(n_inp).float()).to( 30 | target.device 31 | ) 32 | 33 | loss = 0.0 34 | for i, inp in enumerate(input): 35 | loss = loss + scale_weight[i] * cross_entropy2d( 36 | input=inp, target=target, weight=weight, size_average=size_average 37 | ) 38 | 39 | return loss 40 | 41 | 42 | def bootstrapped_cross_entropy2d(input, target, K, weight=None, size_average=True): 43 | 44 | batch_size = input.size()[0] 45 | 46 | def _bootstrap_xentropy_single(input, target, K, weight=None, size_average=True): 47 | 48 | n, c, h, w = input.size() 49 | input = input.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c) 50 | target = target.view(-1) 51 | loss = F.cross_entropy( 52 | input, target, weight=weight, reduce=False, size_average=False, ignore_index=250 53 | ) 54 | 55 | topk_loss, _ = loss.topk(K) 56 | reduced_topk_loss = topk_loss.sum() / K 57 | 58 | return reduced_topk_loss 59 | 60 | loss = 0.0 61 | # Bootstrap from each image not entire batch 62 | for i in range(batch_size): 63 | loss += _bootstrap_xentropy_single( 64 | input=torch.unsqueeze(input[i], 0), 65 | target=torch.unsqueeze(target[i], 0), 66 | K=K, 67 | weight=weight, 68 | size_average=size_average, 69 | ) 70 | return loss / float(batch_size) 71 | -------------------------------------------------------------------------------- /ptsemseg/metrics.py: -------------------------------------------------------------------------------- 1 | # Adapted from score written by wkentaro 2 | # https://github.com/wkentaro/pytorch-fcn/blob/master/torchfcn/utils.py 3 | 4 | import numpy as np 5 | 6 | 7 | class runningScore(object): 8 | def __init__(self, n_classes): 9 | self.n_classes = n_classes 10 | self.confusion_matrix = np.zeros((n_classes, n_classes)) 11 | 12 | def _fast_hist(self, label_true, label_pred, n_class): 13 | mask = (label_true >= 0) & (label_true < n_class) 14 | hist = np.bincount( 15 | n_class * label_true[mask].astype(int) + label_pred[mask], minlength=n_class ** 2 16 | ).reshape(n_class, n_class) 17 | return hist 18 | 19 | def update(self, label_trues, label_preds): 20 | for lt, lp in zip(label_trues, label_preds): 21 | self.confusion_matrix += self._fast_hist(lt.flatten(), lp.flatten(), self.n_classes) 22 | 23 | def get_scores(self): 24 | """Returns accuracy score evaluation result. 
25 | - overall accuracy 26 | - mean accuracy 27 | - mean IU 28 | - fwavacc 29 | """ 30 | hist = self.confusion_matrix 31 | acc = np.diag(hist).sum() / hist.sum() 32 | acc_cls = np.diag(hist) / hist.sum(axis=1) 33 | acc_cls = np.nanmean(acc_cls) 34 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)) 35 | mean_iu = np.nanmean(iu) 36 | freq = hist.sum(axis=1) / hist.sum() 37 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum() 38 | cls_iu = dict(zip(range(self.n_classes), iu)) 39 | 40 | return ( 41 | { 42 | "Overall Acc: \t": acc, 43 | "Mean Acc : \t": acc_cls, 44 | "FreqW Acc : \t": fwavacc, 45 | "Mean IoU : \t": mean_iu, 46 | }, 47 | cls_iu, 48 | ) 49 | 50 | def reset(self): 51 | self.confusion_matrix = np.zeros((self.n_classes, self.n_classes)) 52 | 53 | 54 | class averageMeter(object): 55 | """Computes and stores the average and current value""" 56 | 57 | def __init__(self): 58 | self.reset() 59 | 60 | def reset(self): 61 | self.val = 0 62 | self.avg = 0 63 | self.sum = 0 64 | self.count = 0 65 | 66 | def update(self, val, n=1): 67 | self.val = val 68 | self.sum += val * n 69 | self.count += n 70 | self.avg = self.sum / self.count 71 | -------------------------------------------------------------------------------- /ptsemseg/models/__init__.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torchvision.models as models 3 | 4 | from ptsemseg.models.fcn import fcn8s, fcn16s, fcn32s 5 | from ptsemseg.models.segnet import segnet 6 | from ptsemseg.models.unet import unet 7 | from ptsemseg.models.pspnet import pspnet 8 | from ptsemseg.models.icnet import icnet 9 | from ptsemseg.models.linknet import linknet 10 | from ptsemseg.models.frrn import frrn 11 | 12 | 13 | def get_model(model_dict, n_classes, version=None): 14 | name = model_dict["arch"] 15 | model = _get_model_instance(name) 16 | param_dict = copy.deepcopy(model_dict) 17 | param_dict.pop("arch") 18 | 19 | if name in ["frrnA", "frrnB"]: 20 | model = model(n_classes, **param_dict) 21 | 22 | elif name in ["fcn32s", "fcn16s", "fcn8s"]: 23 | model = model(n_classes=n_classes, **param_dict) 24 | vgg16 = models.vgg16(pretrained=True) 25 | model.init_vgg16_params(vgg16) 26 | 27 | elif name == "segnet": 28 | model = model(n_classes=n_classes, **param_dict) 29 | vgg16 = models.vgg16(pretrained=True) 30 | model.init_vgg16_params(vgg16) 31 | 32 | elif name == "unet": 33 | model = model(n_classes=n_classes, **param_dict) 34 | 35 | elif name == "pspnet": 36 | model = model(n_classes=n_classes, **param_dict) 37 | 38 | elif name == "icnet": 39 | model = model(n_classes=n_classes, **param_dict) 40 | 41 | elif name == "icnetBN": 42 | model = model(n_classes=n_classes, **param_dict) 43 | 44 | else: 45 | model = model(n_classes=n_classes, **param_dict) 46 | 47 | return model 48 | 49 | 50 | def _get_model_instance(name): 51 | try: 52 | return { 53 | "fcn32s": fcn32s, 54 | "fcn8s": fcn8s, 55 | "fcn16s": fcn16s, 56 | "unet": unet, 57 | "segnet": segnet, 58 | "pspnet": pspnet, 59 | "icnet": icnet, 60 | "icnetBN": icnet, 61 | "linknet": linknet, 62 | "frrnA": frrn, 63 | "frrnB": frrn, 64 | }[name] 65 | except: 66 | raise ("Model {} not available".format(name)) 67 | -------------------------------------------------------------------------------- /ptsemseg/models/fcn.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from ptsemseg.models.utils import 
get_upsampling_weight 7 | from ptsemseg.loss import cross_entropy2d 8 | 9 | 10 | # FCN32s 11 | class fcn32s(nn.Module): 12 | def __init__(self, n_classes=21, learned_billinear=False): 13 | super(fcn32s, self).__init__() 14 | self.learned_billinear = learned_billinear 15 | self.n_classes = n_classes 16 | self.loss = functools.partial(cross_entropy2d, size_average=False) 17 | 18 | self.conv_block1 = nn.Sequential( 19 | nn.Conv2d(3, 64, 3, padding=100), 20 | nn.ReLU(inplace=True), 21 | nn.Conv2d(64, 64, 3, padding=1), 22 | nn.ReLU(inplace=True), 23 | nn.MaxPool2d(2, stride=2, ceil_mode=True), 24 | ) 25 | 26 | self.conv_block2 = nn.Sequential( 27 | nn.Conv2d(64, 128, 3, padding=1), 28 | nn.ReLU(inplace=True), 29 | nn.Conv2d(128, 128, 3, padding=1), 30 | nn.ReLU(inplace=True), 31 | nn.MaxPool2d(2, stride=2, ceil_mode=True), 32 | ) 33 | 34 | self.conv_block3 = nn.Sequential( 35 | nn.Conv2d(128, 256, 3, padding=1), 36 | nn.ReLU(inplace=True), 37 | nn.Conv2d(256, 256, 3, padding=1), 38 | nn.ReLU(inplace=True), 39 | nn.Conv2d(256, 256, 3, padding=1), 40 | nn.ReLU(inplace=True), 41 | nn.MaxPool2d(2, stride=2, ceil_mode=True), 42 | ) 43 | 44 | self.conv_block4 = nn.Sequential( 45 | nn.Conv2d(256, 512, 3, padding=1), 46 | nn.ReLU(inplace=True), 47 | nn.Conv2d(512, 512, 3, padding=1), 48 | nn.ReLU(inplace=True), 49 | nn.Conv2d(512, 512, 3, padding=1), 50 | nn.ReLU(inplace=True), 51 | nn.MaxPool2d(2, stride=2, ceil_mode=True), 52 | ) 53 | 54 | self.conv_block5 = nn.Sequential( 55 | nn.Conv2d(512, 512, 3, padding=1), 56 | nn.ReLU(inplace=True), 57 | nn.Conv2d(512, 512, 3, padding=1), 58 | nn.ReLU(inplace=True), 59 | nn.Conv2d(512, 512, 3, padding=1), 60 | nn.ReLU(inplace=True), 61 | nn.MaxPool2d(2, stride=2, ceil_mode=True), 62 | ) 63 | 64 | self.classifier = nn.Sequential( 65 | nn.Conv2d(512, 4096, 7), 66 | nn.ReLU(inplace=True), 67 | nn.Dropout2d(), 68 | nn.Conv2d(4096, 4096, 1), 69 | nn.ReLU(inplace=True), 70 | nn.Dropout2d(), 71 | nn.Conv2d(4096, self.n_classes, 1), 72 | ) 73 | 74 | if self.learned_billinear: 75 | raise NotImplementedError 76 | 77 | def forward(self, x): 78 | conv1 = self.conv_block1(x) 79 | conv2 = self.conv_block2(conv1) 80 | conv3 = self.conv_block3(conv2) 81 | conv4 = self.conv_block4(conv3) 82 | conv5 = self.conv_block5(conv4) 83 | 84 | score = self.classifier(conv5) 85 | 86 | out = F.upsample(score, x.size()[2:]) 87 | 88 | return out 89 | 90 | def init_vgg16_params(self, vgg16, copy_fc8=True): 91 | blocks = [ 92 | self.conv_block1, 93 | self.conv_block2, 94 | self.conv_block3, 95 | self.conv_block4, 96 | self.conv_block5, 97 | ] 98 | 99 | ranges = [[0, 4], [5, 9], [10, 16], [17, 23], [24, 29]] 100 | features = list(vgg16.features.children()) 101 | 102 | for idx, conv_block in enumerate(blocks): 103 | for l1, l2 in zip(features[ranges[idx][0] : ranges[idx][1]], conv_block): 104 | if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d): 105 | assert l1.weight.size() == l2.weight.size() 106 | assert l1.bias.size() == l2.bias.size() 107 | l2.weight.data = l1.weight.data 108 | l2.bias.data = l1.bias.data 109 | for i1, i2 in zip([0, 3], [0, 3]): 110 | l1 = vgg16.classifier[i1] 111 | l2 = self.classifier[i2] 112 | l2.weight.data = l1.weight.data.view(l2.weight.size()) 113 | l2.bias.data = l1.bias.data.view(l2.bias.size()) 114 | n_class = self.classifier[6].weight.size()[0] 115 | if copy_fc8: 116 | l1 = vgg16.classifier[6] 117 | l2 = self.classifier[6] 118 | l2.weight.data = l1.weight.data[:n_class, :].view(l2.weight.size()) 119 | l2.bias.data = l1.bias.data[:n_class] 120 | 121 | 
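# ---------------------------------------------------------------------------
# Illustrative usage sketch for fcn32s (a minimal example, mirroring how
# get_model() in ptsemseg/models/__init__.py wires things up; assumes
# torchvision can download the pretrained VGG16 weights):
#
#     import torch
#     import torchvision.models as models
#
#     model = fcn32s(n_classes=21)
#     model.init_vgg16_params(models.vgg16(pretrained=True))
#     model.eval()
#     with torch.no_grad():
#         logits = model(torch.randn(1, 3, 224, 224))
#     # logits: (1, 21, 224, 224) -- per-pixel class scores, upsampled back
#     # to the input resolution by the final F.upsample call in forward().
# ---------------------------------------------------------------------------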
122 | class fcn16s(nn.Module): 123 | def __init__(self, n_classes=21, learned_billinear=False): 124 | super(fcn16s, self).__init__() 125 | self.learned_billinear = learned_billinear 126 | self.n_classes = n_classes 127 | self.loss = functools.partial(cross_entropy2d, size_average=False) 128 | 129 | self.conv_block1 = nn.Sequential( 130 | nn.Conv2d(3, 64, 3, padding=100), 131 | nn.ReLU(inplace=True), 132 | nn.Conv2d(64, 64, 3, padding=1), 133 | nn.ReLU(inplace=True), 134 | nn.MaxPool2d(2, stride=2, ceil_mode=True), 135 | ) 136 | 137 | self.conv_block2 = nn.Sequential( 138 | nn.Conv2d(64, 128, 3, padding=1), 139 | nn.ReLU(inplace=True), 140 | nn.Conv2d(128, 128, 3, padding=1), 141 | nn.ReLU(inplace=True), 142 | nn.MaxPool2d(2, stride=2, ceil_mode=True), 143 | ) 144 | 145 | self.conv_block3 = nn.Sequential( 146 | nn.Conv2d(128, 256, 3, padding=1), 147 | nn.ReLU(inplace=True), 148 | nn.Conv2d(256, 256, 3, padding=1), 149 | nn.ReLU(inplace=True), 150 | nn.Conv2d(256, 256, 3, padding=1), 151 | nn.ReLU(inplace=True), 152 | nn.MaxPool2d(2, stride=2, ceil_mode=True), 153 | ) 154 | 155 | self.conv_block4 = nn.Sequential( 156 | nn.Conv2d(256, 512, 3, padding=1), 157 | nn.ReLU(inplace=True), 158 | nn.Conv2d(512, 512, 3, padding=1), 159 | nn.ReLU(inplace=True), 160 | nn.Conv2d(512, 512, 3, padding=1), 161 | nn.ReLU(inplace=True), 162 | nn.MaxPool2d(2, stride=2, ceil_mode=True), 163 | ) 164 | 165 | self.conv_block5 = nn.Sequential( 166 | nn.Conv2d(512, 512, 3, padding=1), 167 | nn.ReLU(inplace=True), 168 | nn.Conv2d(512, 512, 3, padding=1), 169 | nn.ReLU(inplace=True), 170 | nn.Conv2d(512, 512, 3, padding=1), 171 | nn.ReLU(inplace=True), 172 | nn.MaxPool2d(2, stride=2, ceil_mode=True), 173 | ) 174 | 175 | self.classifier = nn.Sequential( 176 | nn.Conv2d(512, 4096, 7), 177 | nn.ReLU(inplace=True), 178 | nn.Dropout2d(), 179 | nn.Conv2d(4096, 4096, 1), 180 | nn.ReLU(inplace=True), 181 | nn.Dropout2d(), 182 | nn.Conv2d(4096, self.n_classes, 1), 183 | ) 184 | 185 | self.score_pool4 = nn.Conv2d(512, self.n_classes, 1) 186 | 187 | # TODO: Add support for learned upsampling 188 | if self.learned_billinear: 189 | raise NotImplementedError 190 | 191 | def forward(self, x): 192 | conv1 = self.conv_block1(x) 193 | conv2 = self.conv_block2(conv1) 194 | conv3 = self.conv_block3(conv2) 195 | conv4 = self.conv_block4(conv3) 196 | conv5 = self.conv_block5(conv4) 197 | 198 | score = self.classifier(conv5) 199 | score_pool4 = self.score_pool4(conv4) 200 | 201 | score = F.upsample(score, score_pool4.size()[2:]) 202 | score += score_pool4 203 | out = F.upsample(score, x.size()[2:]) 204 | 205 | return out 206 | 207 | def init_vgg16_params(self, vgg16, copy_fc8=True): 208 | blocks = [ 209 | self.conv_block1, 210 | self.conv_block2, 211 | self.conv_block3, 212 | self.conv_block4, 213 | self.conv_block5, 214 | ] 215 | 216 | ranges = [[0, 4], [5, 9], [10, 16], [17, 23], [24, 29]] 217 | features = list(vgg16.features.children()) 218 | 219 | for idx, conv_block in enumerate(blocks): 220 | for l1, l2 in zip(features[ranges[idx][0] : ranges[idx][1]], conv_block): 221 | if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d): 222 | # print(idx, l1, l2) 223 | assert l1.weight.size() == l2.weight.size() 224 | assert l1.bias.size() == l2.bias.size() 225 | l2.weight.data = l1.weight.data 226 | l2.bias.data = l1.bias.data 227 | for i1, i2 in zip([0, 3], [0, 3]): 228 | l1 = vgg16.classifier[i1] 229 | l2 = self.classifier[i2] 230 | l2.weight.data = l1.weight.data.view(l2.weight.size()) 231 | l2.bias.data = 
l1.bias.data.view(l2.bias.size()) 232 | n_class = self.classifier[6].weight.size()[0] 233 | if copy_fc8: 234 | l1 = vgg16.classifier[6] 235 | l2 = self.classifier[6] 236 | l2.weight.data = l1.weight.data[:n_class, :].view(l2.weight.size()) 237 | l2.bias.data = l1.bias.data[:n_class] 238 | 239 | 240 | # FCN 8s 241 | class fcn8s(nn.Module): 242 | def __init__(self, n_classes=21, learned_billinear=True): 243 | super(fcn8s, self).__init__() 244 | self.learned_billinear = learned_billinear 245 | self.n_classes = n_classes 246 | self.loss = functools.partial(cross_entropy2d, size_average=False) 247 | 248 | self.conv_block1 = nn.Sequential( 249 | nn.Conv2d(3, 64, 3, padding=100), 250 | nn.ReLU(inplace=True), 251 | nn.Conv2d(64, 64, 3, padding=1), 252 | nn.ReLU(inplace=True), 253 | nn.MaxPool2d(2, stride=2, ceil_mode=True), 254 | ) 255 | 256 | self.conv_block2 = nn.Sequential( 257 | nn.Conv2d(64, 128, 3, padding=1), 258 | nn.ReLU(inplace=True), 259 | nn.Conv2d(128, 128, 3, padding=1), 260 | nn.ReLU(inplace=True), 261 | nn.MaxPool2d(2, stride=2, ceil_mode=True), 262 | ) 263 | 264 | self.conv_block3 = nn.Sequential( 265 | nn.Conv2d(128, 256, 3, padding=1), 266 | nn.ReLU(inplace=True), 267 | nn.Conv2d(256, 256, 3, padding=1), 268 | nn.ReLU(inplace=True), 269 | nn.Conv2d(256, 256, 3, padding=1), 270 | nn.ReLU(inplace=True), 271 | nn.MaxPool2d(2, stride=2, ceil_mode=True), 272 | ) 273 | 274 | self.conv_block4 = nn.Sequential( 275 | nn.Conv2d(256, 512, 3, padding=1), 276 | nn.ReLU(inplace=True), 277 | nn.Conv2d(512, 512, 3, padding=1), 278 | nn.ReLU(inplace=True), 279 | nn.Conv2d(512, 512, 3, padding=1), 280 | nn.ReLU(inplace=True), 281 | nn.MaxPool2d(2, stride=2, ceil_mode=True), 282 | ) 283 | 284 | self.conv_block5 = nn.Sequential( 285 | nn.Conv2d(512, 512, 3, padding=1), 286 | nn.ReLU(inplace=True), 287 | nn.Conv2d(512, 512, 3, padding=1), 288 | nn.ReLU(inplace=True), 289 | nn.Conv2d(512, 512, 3, padding=1), 290 | nn.ReLU(inplace=True), 291 | nn.MaxPool2d(2, stride=2, ceil_mode=True), 292 | ) 293 | 294 | self.classifier = nn.Sequential( 295 | nn.Conv2d(512, 4096, 7), 296 | nn.ReLU(inplace=True), 297 | nn.Dropout2d(), 298 | nn.Conv2d(4096, 4096, 1), 299 | nn.ReLU(inplace=True), 300 | nn.Dropout2d(), 301 | nn.Conv2d(4096, self.n_classes, 1), 302 | ) 303 | 304 | self.score_pool4 = nn.Conv2d(512, self.n_classes, 1) 305 | self.score_pool3 = nn.Conv2d(256, self.n_classes, 1) 306 | 307 | if self.learned_billinear: 308 | self.upscore2 = nn.ConvTranspose2d( 309 | self.n_classes, self.n_classes, 4, stride=2, bias=False 310 | ) 311 | self.upscore4 = nn.ConvTranspose2d( 312 | self.n_classes, self.n_classes, 4, stride=2, bias=False 313 | ) 314 | self.upscore8 = nn.ConvTranspose2d( 315 | self.n_classes, self.n_classes, 16, stride=8, bias=False 316 | ) 317 | 318 | for m in self.modules(): 319 | if isinstance(m, nn.ConvTranspose2d): 320 | m.weight.data.copy_( 321 | get_upsampling_weight(m.in_channels, m.out_channels, m.kernel_size[0]) 322 | ) 323 | 324 | def forward(self, x): 325 | conv1 = self.conv_block1(x) 326 | conv2 = self.conv_block2(conv1) 327 | conv3 = self.conv_block3(conv2) 328 | conv4 = self.conv_block4(conv3) 329 | conv5 = self.conv_block5(conv4) 330 | 331 | score = self.classifier(conv5) 332 | 333 | if self.learned_billinear: 334 | upscore2 = self.upscore2(score) 335 | score_pool4c = self.score_pool4(conv4)[ 336 | :, :, 5 : 5 + upscore2.size()[2], 5 : 5 + upscore2.size()[3] 337 | ] 338 | upscore_pool4 = self.upscore4(upscore2 + score_pool4c) 339 | 340 | score_pool3c = self.score_pool3(conv3)[ 341 | :, 
:, 9 : 9 + upscore_pool4.size()[2], 9 : 9 + upscore_pool4.size()[3] 342 | ] 343 | 344 | out = self.upscore8(score_pool3c + upscore_pool4)[ 345 | :, :, 31 : 31 + x.size()[2], 31 : 31 + x.size()[3] 346 | ] 347 | return out.contiguous() 348 | 349 | else: 350 | score_pool4 = self.score_pool4(conv4) 351 | score_pool3 = self.score_pool3(conv3) 352 | score = F.upsample(score, score_pool4.size()[2:]) 353 | score += score_pool4 354 | score = F.upsample(score, score_pool3.size()[2:]) 355 | score += score_pool3 356 | out = F.upsample(score, x.size()[2:]) 357 | 358 | return out 359 | 360 | def init_vgg16_params(self, vgg16, copy_fc8=True): 361 | blocks = [ 362 | self.conv_block1, 363 | self.conv_block2, 364 | self.conv_block3, 365 | self.conv_block4, 366 | self.conv_block5, 367 | ] 368 | 369 | ranges = [[0, 4], [5, 9], [10, 16], [17, 23], [24, 29]] 370 | features = list(vgg16.features.children()) 371 | 372 | for idx, conv_block in enumerate(blocks): 373 | for l1, l2 in zip(features[ranges[idx][0] : ranges[idx][1]], conv_block): 374 | if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d): 375 | assert l1.weight.size() == l2.weight.size() 376 | assert l1.bias.size() == l2.bias.size() 377 | l2.weight.data = l1.weight.data 378 | l2.bias.data = l1.bias.data 379 | for i1, i2 in zip([0, 3], [0, 3]): 380 | l1 = vgg16.classifier[i1] 381 | l2 = self.classifier[i2] 382 | l2.weight.data = l1.weight.data.view(l2.weight.size()) 383 | l2.bias.data = l1.bias.data.view(l2.bias.size()) 384 | n_class = self.classifier[6].weight.size()[0] 385 | if copy_fc8: 386 | l1 = vgg16.classifier[6] 387 | l2 = self.classifier[6] 388 | l2.weight.data = l1.weight.data[:n_class, :].view(l2.weight.size()) 389 | l2.bias.data = l1.bias.data[:n_class] 390 | -------------------------------------------------------------------------------- /ptsemseg/models/frrn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from ptsemseg.models.utils import FRRU, RU, conv2DBatchNormRelu, conv2DGroupNormRelu 6 | 7 | frrn_specs_dic = { 8 | "A": { 9 | "encoder": [[3, 96, 2], [4, 192, 4], [2, 384, 8], [2, 384, 16]], 10 | "decoder": [[2, 192, 8], [2, 192, 4], [2, 48, 2]], 11 | }, 12 | "B": { 13 | "encoder": [[3, 96, 2], [4, 192, 4], [2, 384, 8], [2, 384, 16], [2, 384, 32]], 14 | "decoder": [[2, 192, 16], [2, 192, 8], [2, 192, 4], [2, 48, 2]], 15 | }, 16 | } 17 | 18 | 19 | class frrn(nn.Module): 20 | """ 21 | Full Resolution Residual Networks for Semantic Segmentation 22 | URL: https://arxiv.org/abs/1611.08323 23 | 24 | References: 25 | 1) Original Author's code: https://github.com/TobyPDE/FRRN 26 | 2) TF implementation by @kiwonjoon: https://github.com/hiwonjoon/tf-frrn 27 | """ 28 | 29 | def __init__(self, n_classes=21, model_type="B", group_norm=False, n_groups=16): 30 | super(frrn, self).__init__() 31 | self.n_classes = n_classes 32 | self.model_type = model_type 33 | self.group_norm = group_norm 34 | self.n_groups = n_groups 35 | 36 | if self.group_norm: 37 | self.conv1 = conv2DGroupNormRelu(3, 48, 5, 1, 2) 38 | else: 39 | self.conv1 = conv2DBatchNormRelu(3, 48, 5, 1, 2) 40 | 41 | self.up_residual_units = [] 42 | self.down_residual_units = [] 43 | for i in range(3): 44 | self.up_residual_units.append( 45 | RU( 46 | channels=48, 47 | kernel_size=3, 48 | strides=1, 49 | group_norm=self.group_norm, 50 | n_groups=self.n_groups, 51 | ) 52 | ) 53 | self.down_residual_units.append( 54 | RU( 55 | channels=48, 56 | kernel_size=3, 57 | 
strides=1, 58 | group_norm=self.group_norm, 59 | n_groups=self.n_groups, 60 | ) 61 | ) 62 | 63 | self.up_residual_units = nn.ModuleList(self.up_residual_units) 64 | self.down_residual_units = nn.ModuleList(self.down_residual_units) 65 | 66 | self.split_conv = nn.Conv2d(48, 32, kernel_size=1, padding=0, stride=1, bias=False) 67 | 68 | # each spec is as (n_blocks, channels, scale) 69 | self.encoder_frru_specs = frrn_specs_dic[self.model_type]["encoder"] 70 | 71 | self.decoder_frru_specs = frrn_specs_dic[self.model_type]["decoder"] 72 | 73 | # encoding 74 | prev_channels = 48 75 | self.encoding_frrus = {} 76 | for n_blocks, channels, scale in self.encoder_frru_specs: 77 | for block in range(n_blocks): 78 | key = "_".join(map(str, ["encoding_frru", n_blocks, channels, scale, block])) 79 | setattr( 80 | self, 81 | key, 82 | FRRU( 83 | prev_channels=prev_channels, 84 | out_channels=channels, 85 | scale=scale, 86 | group_norm=self.group_norm, 87 | n_groups=self.n_groups, 88 | ), 89 | ) 90 | prev_channels = channels 91 | 92 | # decoding 93 | self.decoding_frrus = {} 94 | for n_blocks, channels, scale in self.decoder_frru_specs: 95 | # pass through decoding FRRUs 96 | for block in range(n_blocks): 97 | key = "_".join(map(str, ["decoding_frru", n_blocks, channels, scale, block])) 98 | setattr( 99 | self, 100 | key, 101 | FRRU( 102 | prev_channels=prev_channels, 103 | out_channels=channels, 104 | scale=scale, 105 | group_norm=self.group_norm, 106 | n_groups=self.n_groups, 107 | ), 108 | ) 109 | prev_channels = channels 110 | 111 | self.merge_conv = nn.Conv2d( 112 | prev_channels + 32, 48, kernel_size=1, padding=0, stride=1, bias=False 113 | ) 114 | 115 | self.classif_conv = nn.Conv2d( 116 | 48, self.n_classes, kernel_size=1, padding=0, stride=1, bias=True 117 | ) 118 | 119 | def forward(self, x): 120 | 121 | # pass to initial conv 122 | x = self.conv1(x) 123 | 124 | # pass through residual units 125 | for i in range(3): 126 | x = self.up_residual_units[i](x) 127 | 128 | # divide stream 129 | y = x 130 | z = self.split_conv(x) 131 | 132 | prev_channels = 48 133 | # encoding 134 | for n_blocks, channels, scale in self.encoder_frru_specs: 135 | # maxpool bigger feature map 136 | y_pooled = F.max_pool2d(y, stride=2, kernel_size=2, padding=0) 137 | # pass through encoding FRRUs 138 | for block in range(n_blocks): 139 | key = "_".join(map(str, ["encoding_frru", n_blocks, channels, scale, block])) 140 | y, z = getattr(self, key)(y_pooled, z) 141 | prev_channels = channels 142 | 143 | # decoding 144 | for n_blocks, channels, scale in self.decoder_frru_specs: 145 | # bilinear upsample smaller feature map 146 | upsample_size = torch.Size([_s * 2 for _s in y.size()[-2:]]) 147 | y_upsampled = F.upsample(y, size=upsample_size, mode="bilinear", align_corners=True) 148 | # pass through decoding FRRUs 149 | for block in range(n_blocks): 150 | key = "_".join(map(str, ["decoding_frru", n_blocks, channels, scale, block])) 151 | # print("Incoming FRRU Size: ", key, y_upsampled.shape, z.shape) 152 | y, z = getattr(self, key)(y_upsampled, z) 153 | # print("Outgoing FRRU Size: ", key, y.shape, z.shape) 154 | prev_channels = channels 155 | 156 | # merge streams 157 | x = torch.cat( 158 | [F.upsample(y, scale_factor=2, mode="bilinear", align_corners=True), z], dim=1 159 | ) 160 | x = self.merge_conv(x) 161 | 162 | # pass through residual units 163 | for i in range(3): 164 | x = self.down_residual_units[i](x) 165 | 166 | # final 1x1 conv to get classification 167 | x = self.classif_conv(x) 168 | 169 | return x 170 | 
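# ---------------------------------------------------------------------------
# Minimal smoke-test sketch for frrn (illustrative only; assumes input height
# and width divisible by 32 so the stride-2 pooling steps of the "B" variant
# and the matching x2 upsampling steps align exactly):
if __name__ == "__main__":
    import torch

    model = frrn(n_classes=19, model_type="B")
    model.eval()
    with torch.no_grad():
        dummy = torch.randn(1, 3, 256, 256)  # H and W divisible by 32
        out = model(dummy)
    print(out.shape)  # expected: torch.Size([1, 19, 256, 256]) -- full-resolution logits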
-------------------------------------------------------------------------------- /ptsemseg/models/icnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from torch.autograd import Variable 7 | 8 | from ptsemseg import caffe_pb2 9 | from ptsemseg.models.utils import ( 10 | get_interp_size, 11 | cascadeFeatureFusion, 12 | conv2DBatchNormRelu, 13 | residualBlockPSP, 14 | pyramidPooling, 15 | ) 16 | from ptsemseg.loss.loss import multi_scale_cross_entropy2d 17 | 18 | icnet_specs = { 19 | "cityscapes": {"n_classes": 19, "input_size": (1025, 2049), "block_config": [3, 4, 6, 3]} 20 | } 21 | 22 | 23 | class icnet(nn.Module): 24 | 25 | """ 26 | Image Cascade Network 27 | URL: https://arxiv.org/abs/1704.08545 28 | 29 | References: 30 | 1) Original Author's code: https://github.com/hszhao/ICNet 31 | 2) Chainer implementation by @mitmul: https://github.com/mitmul/chainer-pspnet 32 | 3) TensorFlow implementation by @hellochick: https://github.com/hellochick/ICNet-tensorflow 33 | 34 | """ 35 | 36 | def __init__( 37 | self, 38 | n_classes=19, 39 | block_config=[3, 4, 6, 3], 40 | input_size=(1025, 2049), 41 | version=None, 42 | is_batchnorm=True, 43 | ): 44 | 45 | super(icnet, self).__init__() 46 | 47 | bias = not is_batchnorm 48 | 49 | self.block_config = ( 50 | icnet_specs[version]["block_config"] if version is not None else block_config 51 | ) 52 | self.n_classes = icnet_specs[version]["n_classes"] if version is not None else n_classes 53 | self.input_size = icnet_specs[version]["input_size"] if version is not None else input_size 54 | 55 | # Encoder 56 | self.convbnrelu1_1 = conv2DBatchNormRelu( 57 | in_channels=3, 58 | k_size=3, 59 | n_filters=32, 60 | padding=1, 61 | stride=2, 62 | bias=bias, 63 | is_batchnorm=is_batchnorm, 64 | ) 65 | self.convbnrelu1_2 = conv2DBatchNormRelu( 66 | in_channels=32, 67 | k_size=3, 68 | n_filters=32, 69 | padding=1, 70 | stride=1, 71 | bias=bias, 72 | is_batchnorm=is_batchnorm, 73 | ) 74 | self.convbnrelu1_3 = conv2DBatchNormRelu( 75 | in_channels=32, 76 | k_size=3, 77 | n_filters=64, 78 | padding=1, 79 | stride=1, 80 | bias=bias, 81 | is_batchnorm=is_batchnorm, 82 | ) 83 | 84 | # Vanilla Residual Blocks 85 | self.res_block2 = residualBlockPSP( 86 | self.block_config[0], 64, 32, 128, 1, 1, is_batchnorm=is_batchnorm 87 | ) 88 | self.res_block3_conv = residualBlockPSP( 89 | self.block_config[1], 90 | 128, 91 | 64, 92 | 256, 93 | 2, 94 | 1, 95 | include_range="conv", 96 | is_batchnorm=is_batchnorm, 97 | ) 98 | self.res_block3_identity = residualBlockPSP( 99 | self.block_config[1], 100 | 128, 101 | 64, 102 | 256, 103 | 2, 104 | 1, 105 | include_range="identity", 106 | is_batchnorm=is_batchnorm, 107 | ) 108 | 109 | # Dilated Residual Blocks 110 | self.res_block4 = residualBlockPSP( 111 | self.block_config[2], 256, 128, 512, 1, 2, is_batchnorm=is_batchnorm 112 | ) 113 | self.res_block5 = residualBlockPSP( 114 | self.block_config[3], 512, 256, 1024, 1, 4, is_batchnorm=is_batchnorm 115 | ) 116 | 117 | # Pyramid Pooling Module 118 | self.pyramid_pooling = pyramidPooling( 119 | 1024, [6, 3, 2, 1], model_name="icnet", fusion_mode="sum", is_batchnorm=is_batchnorm 120 | ) 121 | 122 | # Final conv layer with kernel 1 in sub4 branch 123 | self.conv5_4_k1 = conv2DBatchNormRelu( 124 | in_channels=1024, 125 | k_size=1, 126 | n_filters=256, 127 | padding=0, 128 | stride=1, 129 | bias=bias, 130 | is_batchnorm=is_batchnorm, 131 | ) 132 | 133 | # 
High-resolution (sub1) branch 134 | self.convbnrelu1_sub1 = conv2DBatchNormRelu( 135 | in_channels=3, 136 | k_size=3, 137 | n_filters=32, 138 | padding=1, 139 | stride=2, 140 | bias=bias, 141 | is_batchnorm=is_batchnorm, 142 | ) 143 | self.convbnrelu2_sub1 = conv2DBatchNormRelu( 144 | in_channels=32, 145 | k_size=3, 146 | n_filters=32, 147 | padding=1, 148 | stride=2, 149 | bias=bias, 150 | is_batchnorm=is_batchnorm, 151 | ) 152 | self.convbnrelu3_sub1 = conv2DBatchNormRelu( 153 | in_channels=32, 154 | k_size=3, 155 | n_filters=64, 156 | padding=1, 157 | stride=2, 158 | bias=bias, 159 | is_batchnorm=is_batchnorm, 160 | ) 161 | self.classification = nn.Conv2d(128, self.n_classes, 1, 1, 0) 162 | 163 | # Cascade Feature Fusion Units 164 | self.cff_sub24 = cascadeFeatureFusion( 165 | self.n_classes, 256, 256, 128, is_batchnorm=is_batchnorm 166 | ) 167 | self.cff_sub12 = cascadeFeatureFusion( 168 | self.n_classes, 128, 64, 128, is_batchnorm=is_batchnorm 169 | ) 170 | 171 | # Define auxiliary loss function 172 | self.loss = multi_scale_cross_entropy2d 173 | 174 | def forward(self, x): 175 | h, w = x.shape[2:] 176 | 177 | # H, W -> H/2, W/2 178 | x_sub2 = F.interpolate( 179 | x, size=get_interp_size(x, s_factor=2), mode="bilinear", align_corners=True 180 | ) 181 | 182 | # H/2, W/2 -> H/4, W/4 183 | x_sub2 = self.convbnrelu1_1(x_sub2) 184 | x_sub2 = self.convbnrelu1_2(x_sub2) 185 | x_sub2 = self.convbnrelu1_3(x_sub2) 186 | 187 | # H/4, W/4 -> H/8, W/8 188 | x_sub2 = F.max_pool2d(x_sub2, 3, 2, 1) 189 | 190 | # H/8, W/8 -> H/16, W/16 191 | x_sub2 = self.res_block2(x_sub2) 192 | x_sub2 = self.res_block3_conv(x_sub2) 193 | # H/16, W/16 -> H/32, W/32 194 | x_sub4 = F.interpolate( 195 | x_sub2, size=get_interp_size(x_sub2, s_factor=2), mode="bilinear", align_corners=True 196 | ) 197 | x_sub4 = self.res_block3_identity(x_sub4) 198 | 199 | x_sub4 = self.res_block4(x_sub4) 200 | x_sub4 = self.res_block5(x_sub4) 201 | 202 | x_sub4 = self.pyramid_pooling(x_sub4) 203 | x_sub4 = self.conv5_4_k1(x_sub4) 204 | 205 | x_sub1 = self.convbnrelu1_sub1(x) 206 | x_sub1 = self.convbnrelu2_sub1(x_sub1) 207 | x_sub1 = self.convbnrelu3_sub1(x_sub1) 208 | 209 | x_sub24, sub4_cls = self.cff_sub24(x_sub4, x_sub2) 210 | x_sub12, sub24_cls = self.cff_sub12(x_sub24, x_sub1) 211 | 212 | x_sub12 = F.interpolate( 213 | x_sub12, size=get_interp_size(x_sub12, z_factor=2), mode="bilinear", align_corners=True 214 | ) 215 | x_sub4 = self.res_block3_identity(x_sub4) 216 | sub124_cls = self.classification(x_sub12) 217 | 218 | if self.training: 219 | return (sub124_cls, sub24_cls, sub4_cls) 220 | else: 221 | sub124_cls = F.interpolate( 222 | sub124_cls, 223 | size=get_interp_size(sub124_cls, z_factor=4), 224 | mode="bilinear", 225 | align_corners=True, 226 | ) 227 | return sub124_cls 228 | 229 | def load_pretrained_model(self, model_path): 230 | """ 231 | Load weights from caffemodel w/o caffe dependency 232 | and plug them in corresponding modules 233 | """ 234 | # My eyes and my heart both hurt when writing this method 235 | 236 | # Only care about layer_types that have trainable parameters 237 | ltypes = [ 238 | "BNData", 239 | "ConvolutionData", 240 | "HoleConvolutionData", 241 | "Convolution", 242 | ] # Convolution type for conv3_sub1_proj 243 | 244 | def _get_layer_params(layer, ltype): 245 | 246 | if ltype == "BNData": 247 | gamma = np.array(layer.blobs[0].data) 248 | beta = np.array(layer.blobs[1].data) 249 | mean = np.array(layer.blobs[2].data) 250 | var = np.array(layer.blobs[3].data) 251 | return [mean, var, gamma, beta] 252 | 
253 | elif ltype in ["ConvolutionData", "HoleConvolutionData", "Convolution"]: 254 | is_bias = layer.convolution_param.bias_term 255 | weights = np.array(layer.blobs[0].data) 256 | bias = [] 257 | if is_bias: 258 | bias = np.array(layer.blobs[1].data) 259 | return [weights, bias] 260 | 261 | elif ltype == "InnerProduct": 262 | raise Exception("Fully connected layers {}, not supported".format(ltype)) 263 | 264 | else: 265 | raise Exception("Unkown layer type {}".format(ltype)) 266 | 267 | net = caffe_pb2.NetParameter() 268 | with open(model_path, "rb") as model_file: 269 | net.MergeFromString(model_file.read()) 270 | 271 | # dict formatted as -> key: :: value: 272 | layer_types = {} 273 | # dict formatted as -> key: :: value:[] 274 | layer_params = {} 275 | 276 | for l in net.layer: 277 | lname = l.name 278 | ltype = l.type 279 | lbottom = l.bottom 280 | ltop = l.top 281 | if ltype in ltypes: 282 | print("Processing layer {} | {}, {}".format(lname, lbottom, ltop)) 283 | layer_types[lname] = ltype 284 | layer_params[lname] = _get_layer_params(l, ltype) 285 | # if len(l.blobs) > 0: 286 | # print(lname, ltype, lbottom, ltop, len(l.blobs)) 287 | 288 | # Set affine=False for all batchnorm modules 289 | def _no_affine_bn(module=None): 290 | if isinstance(module, nn.BatchNorm2d): 291 | module.affine = False 292 | 293 | if len([m for m in module.children()]) > 0: 294 | for child in module.children(): 295 | _no_affine_bn(child) 296 | 297 | # _no_affine_bn(self) 298 | 299 | def _transfer_conv(layer_name, module): 300 | weights, bias = layer_params[layer_name] 301 | w_shape = np.array(module.weight.size()) 302 | 303 | print( 304 | "CONV {}: Original {} and trans weights {}".format( 305 | layer_name, w_shape, weights.shape 306 | ) 307 | ) 308 | 309 | module.weight.data.copy_(torch.from_numpy(weights).view_as(module.weight)) 310 | 311 | if len(bias) != 0: 312 | b_shape = np.array(module.bias.size()) 313 | print( 314 | "CONV {}: Original {} and trans bias {}".format(layer_name, b_shape, bias.shape) 315 | ) 316 | module.bias.data.copy_(torch.from_numpy(bias).view_as(module.bias)) 317 | 318 | def _transfer_bn(conv_layer_name, bn_module): 319 | mean, var, gamma, beta = layer_params[conv_layer_name + "/bn"] 320 | print( 321 | "BN {}: Original {} and trans weights {}".format( 322 | conv_layer_name, bn_module.running_mean.size(), mean.shape 323 | ) 324 | ) 325 | bn_module.running_mean.copy_(torch.from_numpy(mean).view_as(bn_module.running_mean)) 326 | bn_module.running_var.copy_(torch.from_numpy(var).view_as(bn_module.running_var)) 327 | bn_module.weight.data.copy_(torch.from_numpy(gamma).view_as(bn_module.weight)) 328 | bn_module.bias.data.copy_(torch.from_numpy(beta).view_as(bn_module.bias)) 329 | 330 | def _transfer_conv_bn(conv_layer_name, mother_module): 331 | conv_module = mother_module[0] 332 | _transfer_conv(conv_layer_name, conv_module) 333 | 334 | if conv_layer_name + "/bn" in layer_params.keys(): 335 | bn_module = mother_module[1] 336 | _transfer_bn(conv_layer_name, bn_module) 337 | 338 | def _transfer_residual(block_name, block): 339 | block_module, n_layers = block[0], block[1] 340 | prefix = block_name[:5] 341 | 342 | if ("bottleneck" in block_name) or ("identity" not in block_name): # Conv block 343 | bottleneck = block_module.layers[0] 344 | bottleneck_conv_bn_dic = { 345 | prefix + "_1_1x1_reduce": bottleneck.cbr1.cbr_unit, 346 | prefix + "_1_3x3": bottleneck.cbr2.cbr_unit, 347 | prefix + "_1_1x1_proj": bottleneck.cb4.cb_unit, 348 | prefix + "_1_1x1_increase": bottleneck.cb3.cb_unit, 349 | } 
350 | 351 | for k, v in bottleneck_conv_bn_dic.items(): 352 | _transfer_conv_bn(k, v) 353 | 354 | if ("identity" in block_name) or ("bottleneck" not in block_name): # Identity blocks 355 | base_idx = 2 if "identity" in block_name else 1 356 | 357 | for layer_idx in range(2, n_layers + 1): 358 | residual_layer = block_module.layers[layer_idx - base_idx] 359 | residual_conv_bn_dic = { 360 | "_".join( 361 | map(str, [prefix, layer_idx, "1x1_reduce"]) 362 | ): residual_layer.cbr1.cbr_unit, 363 | "_".join( 364 | map(str, [prefix, layer_idx, "3x3"]) 365 | ): residual_layer.cbr2.cbr_unit, 366 | "_".join( 367 | map(str, [prefix, layer_idx, "1x1_increase"]) 368 | ): residual_layer.cb3.cb_unit, 369 | } 370 | 371 | for k, v in residual_conv_bn_dic.items(): 372 | _transfer_conv_bn(k, v) 373 | 374 | convbn_layer_mapping = { 375 | "conv1_1_3x3_s2": self.convbnrelu1_1.cbr_unit, 376 | "conv1_2_3x3": self.convbnrelu1_2.cbr_unit, 377 | "conv1_3_3x3": self.convbnrelu1_3.cbr_unit, 378 | "conv1_sub1": self.convbnrelu1_sub1.cbr_unit, 379 | "conv2_sub1": self.convbnrelu2_sub1.cbr_unit, 380 | "conv3_sub1": self.convbnrelu3_sub1.cbr_unit, 381 | # 'conv5_3_pool6_conv': self.pyramid_pooling.paths[0].cbr_unit, 382 | # 'conv5_3_pool3_conv': self.pyramid_pooling.paths[1].cbr_unit, 383 | # 'conv5_3_pool2_conv': self.pyramid_pooling.paths[2].cbr_unit, 384 | # 'conv5_3_pool1_conv': self.pyramid_pooling.paths[3].cbr_unit, 385 | "conv5_4_k1": self.conv5_4_k1.cbr_unit, 386 | "conv_sub4": self.cff_sub24.low_dilated_conv_bn.cb_unit, 387 | "conv3_1_sub2_proj": self.cff_sub24.high_proj_conv_bn.cb_unit, 388 | "conv_sub2": self.cff_sub12.low_dilated_conv_bn.cb_unit, 389 | "conv3_sub1_proj": self.cff_sub12.high_proj_conv_bn.cb_unit, 390 | } 391 | 392 | residual_layers = { 393 | "conv2": [self.res_block2, self.block_config[0]], 394 | "conv3_bottleneck": [self.res_block3_conv, self.block_config[1]], 395 | "conv3_identity": [self.res_block3_identity, self.block_config[1]], 396 | "conv4": [self.res_block4, self.block_config[2]], 397 | "conv5": [self.res_block5, self.block_config[3]], 398 | } 399 | 400 | # Transfer weights for all non-residual conv+bn layers 401 | for k, v in convbn_layer_mapping.items(): 402 | _transfer_conv_bn(k, v) 403 | 404 | # Transfer weights for final non-bn conv layer 405 | _transfer_conv("conv6_cls", self.classification) 406 | _transfer_conv("conv6_sub4", self.cff_sub24.low_classifier_conv) 407 | _transfer_conv("conv6_sub2", self.cff_sub12.low_classifier_conv) 408 | 409 | # Transfer weights for all residual layers 410 | for k, v in residual_layers.items(): 411 | _transfer_residual(k, v) 412 | 413 | def tile_predict(self, imgs, include_flip_mode=True): 414 | """ 415 | Predict by takin overlapping tiles from the image. 416 | 417 | Strides are adaptively computed from the imgs shape 418 | and input size 419 | 420 | :param imgs: torch.Tensor with shape [N, C, H, W] in BGR format 421 | :param side: int with side length of model input 422 | :param n_classes: int with number of classes in seg output. 
423 | """ 424 | 425 | side_x, side_y = self.input_size 426 | n_classes = self.n_classes 427 | n_samples, c, h, w = imgs.shape 428 | # n = int(max(h,w) / float(side) + 1) 429 | n_x = int(h / float(side_x) + 1) 430 | n_y = int(w / float(side_y) + 1) 431 | stride_x = (h - side_x) / float(n_x) 432 | stride_y = (w - side_y) / float(n_y) 433 | 434 | x_ends = [[int(i * stride_x), int(i * stride_x) + side_x] for i in range(n_x + 1)] 435 | y_ends = [[int(i * stride_y), int(i * stride_y) + side_y] for i in range(n_y + 1)] 436 | 437 | pred = np.zeros([n_samples, n_classes, h, w]) 438 | count = np.zeros([h, w]) 439 | 440 | slice_count = 0 441 | for sx, ex in x_ends: 442 | for sy, ey in y_ends: 443 | slice_count += 1 444 | 445 | imgs_slice = imgs[:, :, sx:ex, sy:ey] 446 | if include_flip_mode: 447 | imgs_slice_flip = torch.from_numpy( 448 | np.copy(imgs_slice.cpu().numpy()[:, :, :, ::-1]) 449 | ).float() 450 | 451 | is_model_on_cuda = next(self.parameters()).is_cuda 452 | 453 | inp = Variable(imgs_slice, volatile=True) 454 | if include_flip_mode: 455 | flp = Variable(imgs_slice_flip, volatile=True) 456 | 457 | if is_model_on_cuda: 458 | inp = inp.cuda() 459 | if include_flip_mode: 460 | flp = flp.cuda() 461 | 462 | psub1 = F.softmax(self.forward(inp), dim=1).data.cpu().numpy() 463 | if include_flip_mode: 464 | psub2 = F.softmax(self.forward(flp), dim=1).data.cpu().numpy() 465 | psub = (psub1 + psub2[:, :, :, ::-1]) / 2.0 466 | else: 467 | psub = psub1 468 | 469 | pred[:, :, sx:ex, sy:ey] = psub 470 | count[sx:ex, sy:ey] += 1.0 471 | 472 | score = (pred / count[None, None, ...]).astype(np.float32) 473 | return score / np.expand_dims(score.sum(axis=1), axis=1) 474 | 475 | 476 | # For Testing Purposes only 477 | if __name__ == "__main__": 478 | cd = 0 479 | import os 480 | import scipy.misc as m 481 | from ptsemseg.loader.cityscapes_loader import cityscapesLoader as cl 482 | 483 | ic = icnet(version="cityscapes", is_batchnorm=False) 484 | 485 | # Just need to do this one time 486 | caffemodel_dir_path = "PATH_TO_ICNET_DIR/evaluation/model" 487 | ic.load_pretrained_model( 488 | model_path=os.path.join(caffemodel_dir_path, "icnet_cityscapes_train_30k.caffemodel") 489 | ) 490 | # ic.load_pretrained_model(model_path=os.path.join(caffemodel_dir_path, 491 | # 'icnet_cityscapes_train_30k_bnnomerge.caffemodel')) 492 | # ic.load_pretrained_model(model_path=os.path.join(caffemodel_dir_path, 493 | # 'icnet_cityscapes_trainval_90k.caffemodel')) 494 | # ic.load_pretrained_model(model_path=os.path.join(caffemodel_dir_path, 495 | # 'icnet_cityscapes_trainval_90k_bnnomerge.caffemodel')) 496 | 497 | # ic.load_state_dict(torch.load('ic.pth')) 498 | 499 | ic.float() 500 | ic.cuda(cd) 501 | ic.eval() 502 | 503 | dataset_root_dir = "PATH_TO_CITYSCAPES_DIR" 504 | dst = cl(root=dataset_root_dir) 505 | img = m.imread( 506 | os.path.join( 507 | dataset_root_dir, 508 | "leftImg8bit/demoVideo/stuttgart_00/stuttgart_00_000000_000010_leftImg8bit.png", 509 | ) 510 | ) 511 | m.imsave("test_input.png", img) 512 | orig_size = img.shape[:-1] 513 | img = m.imresize(img, ic.input_size) # uint8 with RGB mode 514 | img = img.transpose(2, 0, 1) 515 | img = img.astype(np.float64) 516 | img -= np.array([123.68, 116.779, 103.939])[:, None, None] 517 | img = np.copy(img[::-1, :, :]) 518 | img = torch.from_numpy(img).float() 519 | img = img.unsqueeze(0) 520 | 521 | out = ic.tile_predict(img) 522 | pred = np.argmax(out, axis=1)[0] 523 | pred = pred.astype(np.float32) 524 | pred = m.imresize(pred, orig_size, "nearest", mode="F") # float32 with F mode 
525 | decoded = dst.decode_segmap(pred) 526 | m.imsave("test_output.png", decoded) 527 | # m.imsave('test_output.png', pred) 528 | 529 | checkpoints_dir_path = "checkpoints" 530 | if not os.path.exists(checkpoints_dir_path): 531 | os.mkdir(checkpoints_dir_path) 532 | ic = torch.nn.DataParallel(ic, device_ids=range(torch.cuda.device_count())) 533 | state = {"model_state": ic.state_dict()} 534 | torch.save(state, os.path.join(checkpoints_dir_path, "icnet_cityscapes_train_30k.pth")) 535 | # torch.save(state, os.path.join(checkpoints_dir_path, "icnetBN_cityscapes_train_30k.pth")) 536 | # torch.save(state, os.path.join(checkpoints_dir_path, "icnet_cityscapes_trainval_90k.pth")) 537 | # torch.save(state, os.path.join(checkpoints_dir_path, "icnetBN_cityscapes_trainval_90k.pth")) 538 | print("Output Shape {} \t Input Shape {}".format(out.shape, img.shape)) 539 | -------------------------------------------------------------------------------- /ptsemseg/models/linknet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from ptsemseg.models.utils import conv2DBatchNormRelu, linknetUp, residualBlock 4 | 5 | 6 | class linknet(nn.Module): 7 | def __init__( 8 | self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3, is_batchnorm=True 9 | ): 10 | super(linknet, self).__init__() 11 | self.is_deconv = is_deconv 12 | self.in_channels = in_channels 13 | self.is_batchnorm = is_batchnorm 14 | self.feature_scale = feature_scale 15 | self.layers = [2, 2, 2, 2] # Currently hardcoded for ResNet-18 16 | 17 | filters = [64, 128, 256, 512] 18 | filters = [x / self.feature_scale for x in filters] 19 | 20 | self.inplanes = filters[0] 21 | 22 | # Encoder 23 | self.convbnrelu1 = conv2DBatchNormRelu( 24 | in_channels=3, k_size=7, n_filters=64, padding=3, stride=2, bias=False 25 | ) 26 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 27 | 28 | block = residualBlock 29 | self.encoder1 = self._make_layer(block, filters[0], self.layers[0]) 30 | self.encoder2 = self._make_layer(block, filters[1], self.layers[1], stride=2) 31 | self.encoder3 = self._make_layer(block, filters[2], self.layers[2], stride=2) 32 | self.encoder4 = self._make_layer(block, filters[3], self.layers[3], stride=2) 33 | self.avgpool = nn.AvgPool2d(7) 34 | 35 | # Decoder 36 | self.decoder4 = linknetUp(filters[3], filters[2]) 37 | self.decoder4 = linknetUp(filters[2], filters[1]) 38 | self.decoder4 = linknetUp(filters[1], filters[0]) 39 | self.decoder4 = linknetUp(filters[0], filters[0]) 40 | 41 | # Final Classifier 42 | self.finaldeconvbnrelu1 = nn.Sequential( 43 | nn.ConvTranspose2d(filters[0], 32 / feature_scale, 3, 2, 1), 44 | nn.BatchNorm2d(32 / feature_scale), 45 | nn.ReLU(inplace=True), 46 | ) 47 | self.finalconvbnrelu2 = conv2DBatchNormRelu( 48 | in_channels=32 / feature_scale, 49 | k_size=3, 50 | n_filters=32 / feature_scale, 51 | padding=1, 52 | stride=1, 53 | ) 54 | self.finalconv3 = nn.Conv2d(32 / feature_scale, n_classes, 2, 2, 0) 55 | 56 | def _make_layer(self, block, planes, blocks, stride=1): 57 | downsample = None 58 | if stride != 1 or self.inplanes != planes * block.expansion: 59 | downsample = nn.Sequential( 60 | nn.Conv2d( 61 | self.inplanes, 62 | planes * block.expansion, 63 | kernel_size=1, 64 | stride=stride, 65 | bias=False, 66 | ), 67 | nn.BatchNorm2d(planes * block.expansion), 68 | ) 69 | layers = [] 70 | layers.append(block(self.inplanes, planes, stride, downsample)) 71 | self.inplanes = planes * block.expansion 72 | for i in range(1, 
blocks): 73 | layers.append(block(self.inplanes, planes)) 74 | return nn.Sequential(*layers) 75 | 76 | def forward(self, x): 77 | # Encoder 78 | x = self.convbnrelu1(x) 79 | x = self.maxpool(x) 80 | 81 | e1 = self.encoder1(x) 82 | e2 = self.encoder2(e1) 83 | e3 = self.encoder3(e2) 84 | e4 = self.encoder4(e3) 85 | 86 | # Decoder with Skip Connections 87 | d4 = self.decoder4(e4) 88 | d4 += e3 89 | d3 = self.decoder3(d4) 90 | d3 += e2 91 | d2 = self.decoder2(d3) 92 | d2 += e1 93 | d1 = self.decoder1(d2) 94 | 95 | # Final Classification 96 | f1 = self.finaldeconvbnrelu1(d1) 97 | f2 = self.finalconvbnrelu2(f1) 98 | f3 = self.finalconv3(f2) 99 | 100 | return f3 101 | -------------------------------------------------------------------------------- /ptsemseg/models/pspnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from torch.autograd import Variable 7 | 8 | from ptsemseg import caffe_pb2 9 | from ptsemseg.models.utils import conv2DBatchNormRelu, residualBlockPSP, pyramidPooling 10 | from ptsemseg.loss.loss import multi_scale_cross_entropy2d 11 | 12 | pspnet_specs = { 13 | "pascal": {"n_classes": 21, "input_size": (473, 473), "block_config": [3, 4, 23, 3]}, 14 | "cityscapes": {"n_classes": 19, "input_size": (713, 713), "block_config": [3, 4, 23, 3]}, 15 | "ade20k": {"n_classes": 150, "input_size": (473, 473), "block_config": [3, 4, 6, 3]}, 16 | } 17 | 18 | 19 | class pspnet(nn.Module): 20 | 21 | """ 22 | Pyramid Scene Parsing Network 23 | URL: https://arxiv.org/abs/1612.01105 24 | 25 | References: 26 | 1) Original Author's code: https://github.com/hszhao/PSPNet 27 | 2) Chainer implementation by @mitmul: https://github.com/mitmul/chainer-pspnet 28 | 3) TensorFlow implementation by @hellochick: https://github.com/hellochick/PSPNet-tensorflow 29 | 30 | Visualization: 31 | http://dgschwend.github.io/netscope/#/gist/6bfb59e6a3cfcb4e2bb8d47f827c2928 32 | 33 | """ 34 | 35 | def __init__( 36 | self, n_classes=21, block_config=[3, 4, 23, 3], input_size=(473, 473), version=None 37 | ): 38 | 39 | super(pspnet, self).__init__() 40 | 41 | self.block_config = ( 42 | pspnet_specs[version]["block_config"] if version is not None else block_config 43 | ) 44 | self.n_classes = pspnet_specs[version]["n_classes"] if version is not None else n_classes 45 | self.input_size = pspnet_specs[version]["input_size"] if version is not None else input_size 46 | 47 | # Encoder 48 | self.convbnrelu1_1 = conv2DBatchNormRelu( 49 | in_channels=3, k_size=3, n_filters=64, padding=1, stride=2, bias=False 50 | ) 51 | self.convbnrelu1_2 = conv2DBatchNormRelu( 52 | in_channels=64, k_size=3, n_filters=64, padding=1, stride=1, bias=False 53 | ) 54 | self.convbnrelu1_3 = conv2DBatchNormRelu( 55 | in_channels=64, k_size=3, n_filters=128, padding=1, stride=1, bias=False 56 | ) 57 | 58 | # Vanilla Residual Blocks 59 | self.res_block2 = residualBlockPSP(self.block_config[0], 128, 64, 256, 1, 1) 60 | self.res_block3 = residualBlockPSP(self.block_config[1], 256, 128, 512, 2, 1) 61 | 62 | # Dilated Residual Blocks 63 | self.res_block4 = residualBlockPSP(self.block_config[2], 512, 256, 1024, 1, 2) 64 | self.res_block5 = residualBlockPSP(self.block_config[3], 1024, 512, 2048, 1, 4) 65 | 66 | # Pyramid Pooling Module 67 | self.pyramid_pooling = pyramidPooling(2048, [6, 3, 2, 1]) 68 | 69 | # Final conv layers 70 | self.cbr_final = conv2DBatchNormRelu(4096, 512, 3, 1, 1, False) 71 | self.dropout = 
nn.Dropout2d(p=0.1, inplace=False) 72 | self.classification = nn.Conv2d(512, self.n_classes, 1, 1, 0) 73 | 74 | # Auxiliary layers for training 75 | self.convbnrelu4_aux = conv2DBatchNormRelu( 76 | in_channels=1024, k_size=3, n_filters=256, padding=1, stride=1, bias=False 77 | ) 78 | self.aux_cls = nn.Conv2d(256, self.n_classes, 1, 1, 0) 79 | 80 | # Define auxiliary loss function 81 | self.loss = multi_scale_cross_entropy2d 82 | 83 | def forward(self, x): 84 | inp_shape = x.shape[2:] 85 | 86 | # H, W -> H/2, W/2 87 | x = self.convbnrelu1_1(x) 88 | x = self.convbnrelu1_2(x) 89 | x = self.convbnrelu1_3(x) 90 | 91 | # H/2, W/2 -> H/4, W/4 92 | x = F.max_pool2d(x, 3, 2, 1) 93 | 94 | # H/4, W/4 -> H/8, W/8 95 | x = self.res_block2(x) 96 | x = self.res_block3(x) 97 | x = self.res_block4(x) 98 | 99 | # Auxiliary layers for training 100 | if self.training: 101 | x_aux = self.convbnrelu4_aux(x) 102 | x_aux = self.dropout(x_aux) 103 | x_aux = self.aux_cls(x_aux) 104 | 105 | x = self.res_block5(x) 106 | 107 | x = self.pyramid_pooling(x) 108 | 109 | x = self.cbr_final(x) 110 | x = self.dropout(x) 111 | 112 | x = self.classification(x) 113 | x = F.interpolate(x, size=inp_shape, mode="bilinear", align_corners=True) 114 | 115 | if self.training: 116 | return (x, x_aux) 117 | else: # eval mode 118 | return x 119 | 120 | def load_pretrained_model(self, model_path): 121 | """ 122 | Load weights from caffemodel w/o caffe dependency 123 | and plug them in corresponding modules 124 | """ 125 | # My eyes and my heart both hurt when writing this method 126 | 127 | # Only care about layer_types that have trainable parameters 128 | ltypes = ["BNData", "ConvolutionData", "HoleConvolutionData"] 129 | 130 | def _get_layer_params(layer, ltype): 131 | 132 | if ltype == "BNData": 133 | gamma = np.array(layer.blobs[0].data) 134 | beta = np.array(layer.blobs[1].data) 135 | mean = np.array(layer.blobs[2].data) 136 | var = np.array(layer.blobs[3].data) 137 | return [mean, var, gamma, beta] 138 | 139 | elif ltype in ["ConvolutionData", "HoleConvolutionData"]: 140 | is_bias = layer.convolution_param.bias_term 141 | weights = np.array(layer.blobs[0].data) 142 | bias = [] 143 | if is_bias: 144 | bias = np.array(layer.blobs[1].data) 145 | return [weights, bias] 146 | 147 | elif ltype == "InnerProduct": 148 | raise Exception("Fully connected layers {}, not supported".format(ltype)) 149 | 150 | else: 151 | raise Exception("Unkown layer type {}".format(ltype)) 152 | 153 | net = caffe_pb2.NetParameter() 154 | with open(model_path, "rb") as model_file: 155 | net.MergeFromString(model_file.read()) 156 | 157 | # dict formatted as -> key: :: value: 158 | layer_types = {} 159 | # dict formatted as -> key: :: value:[] 160 | layer_params = {} 161 | 162 | for l in net.layer: 163 | lname = l.name 164 | ltype = l.type 165 | if ltype in ltypes: 166 | print("Processing layer {}".format(lname)) 167 | layer_types[lname] = ltype 168 | layer_params[lname] = _get_layer_params(l, ltype) 169 | 170 | # Set affine=False for all batchnorm modules 171 | def _no_affine_bn(module=None): 172 | if isinstance(module, nn.BatchNorm2d): 173 | module.affine = False 174 | 175 | if len([m for m in module.children()]) > 0: 176 | for child in module.children(): 177 | _no_affine_bn(child) 178 | 179 | # _no_affine_bn(self) 180 | 181 | def _transfer_conv(layer_name, module): 182 | weights, bias = layer_params[layer_name] 183 | w_shape = np.array(module.weight.size()) 184 | 185 | print( 186 | "CONV {}: Original {} and trans weights {}".format( 187 | layer_name, w_shape, 
weights.shape 188 | ) 189 | ) 190 | 191 | module.weight.data.copy_(torch.from_numpy(weights).view_as(module.weight)) 192 | 193 | if len(bias) != 0: 194 | b_shape = np.array(module.bias.size()) 195 | print( 196 | "CONV {}: Original {} and trans bias {}".format(layer_name, b_shape, bias.shape) 197 | ) 198 | module.bias.data.copy_(torch.from_numpy(bias).view_as(module.bias)) 199 | 200 | def _transfer_conv_bn(conv_layer_name, mother_module): 201 | conv_module = mother_module[0] 202 | bn_module = mother_module[1] 203 | 204 | _transfer_conv(conv_layer_name, conv_module) 205 | 206 | mean, var, gamma, beta = layer_params[conv_layer_name + "/bn"] 207 | print( 208 | "BN {}: Original {} and trans weights {}".format( 209 | conv_layer_name, bn_module.running_mean.size(), mean.shape 210 | ) 211 | ) 212 | bn_module.running_mean.copy_(torch.from_numpy(mean).view_as(bn_module.running_mean)) 213 | bn_module.running_var.copy_(torch.from_numpy(var).view_as(bn_module.running_var)) 214 | bn_module.weight.data.copy_(torch.from_numpy(gamma).view_as(bn_module.weight)) 215 | bn_module.bias.data.copy_(torch.from_numpy(beta).view_as(bn_module.bias)) 216 | 217 | def _transfer_residual(prefix, block): 218 | block_module, n_layers = block[0], block[1] 219 | 220 | bottleneck = block_module.layers[0] 221 | bottleneck_conv_bn_dic = { 222 | prefix + "_1_1x1_reduce": bottleneck.cbr1.cbr_unit, 223 | prefix + "_1_3x3": bottleneck.cbr2.cbr_unit, 224 | prefix + "_1_1x1_proj": bottleneck.cb4.cb_unit, 225 | prefix + "_1_1x1_increase": bottleneck.cb3.cb_unit, 226 | } 227 | 228 | for k, v in bottleneck_conv_bn_dic.items(): 229 | _transfer_conv_bn(k, v) 230 | 231 | for layer_idx in range(2, n_layers + 1): 232 | residual_layer = block_module.layers[layer_idx - 1] 233 | residual_conv_bn_dic = { 234 | "_".join( 235 | map(str, [prefix, layer_idx, "1x1_reduce"]) 236 | ): residual_layer.cbr1.cbr_unit, 237 | "_".join(map(str, [prefix, layer_idx, "3x3"])): residual_layer.cbr2.cbr_unit, 238 | "_".join( 239 | map(str, [prefix, layer_idx, "1x1_increase"]) 240 | ): residual_layer.cb3.cb_unit, 241 | } 242 | 243 | for k, v in residual_conv_bn_dic.items(): 244 | _transfer_conv_bn(k, v) 245 | 246 | convbn_layer_mapping = { 247 | "conv1_1_3x3_s2": self.convbnrelu1_1.cbr_unit, 248 | "conv1_2_3x3": self.convbnrelu1_2.cbr_unit, 249 | "conv1_3_3x3": self.convbnrelu1_3.cbr_unit, 250 | "conv5_3_pool6_conv": self.pyramid_pooling.paths[0].cbr_unit, 251 | "conv5_3_pool3_conv": self.pyramid_pooling.paths[1].cbr_unit, 252 | "conv5_3_pool2_conv": self.pyramid_pooling.paths[2].cbr_unit, 253 | "conv5_3_pool1_conv": self.pyramid_pooling.paths[3].cbr_unit, 254 | "conv5_4": self.cbr_final.cbr_unit, 255 | "conv4_" + str(self.block_config[2] + 1): self.convbnrelu4_aux.cbr_unit, 256 | } # Auxiliary layers for training 257 | 258 | residual_layers = { 259 | "conv2": [self.res_block2, self.block_config[0]], 260 | "conv3": [self.res_block3, self.block_config[1]], 261 | "conv4": [self.res_block4, self.block_config[2]], 262 | "conv5": [self.res_block5, self.block_config[3]], 263 | } 264 | 265 | # Transfer weights for all non-residual conv+bn layers 266 | for k, v in convbn_layer_mapping.items(): 267 | _transfer_conv_bn(k, v) 268 | 269 | # Transfer weights for final non-bn conv layer 270 | _transfer_conv("conv6", self.classification) 271 | _transfer_conv("conv6_1", self.aux_cls) 272 | 273 | # Transfer weights for all residual layers 274 | for k, v in residual_layers.items(): 275 | _transfer_residual(k, v) 276 | 277 | def tile_predict(self, imgs, include_flip_mode=True): 278 | 
""" 279 | Predict by takin overlapping tiles from the image. 280 | 281 | Strides are adaptively computed from the imgs shape 282 | and input size 283 | 284 | :param imgs: torch.Tensor with shape [N, C, H, W] in BGR format 285 | :param side: int with side length of model input 286 | :param n_classes: int with number of classes in seg output. 287 | """ 288 | 289 | side_x, side_y = self.input_size 290 | n_classes = self.n_classes 291 | n_samples, c, h, w = imgs.shape 292 | # n = int(max(h,w) / float(side) + 1) 293 | n_x = int(h / float(side_x) + 1) 294 | n_y = int(w / float(side_y) + 1) 295 | stride_x = (h - side_x) / float(n_x) 296 | stride_y = (w - side_y) / float(n_y) 297 | 298 | x_ends = [[int(i * stride_x), int(i * stride_x) + side_x] for i in range(n_x + 1)] 299 | y_ends = [[int(i * stride_y), int(i * stride_y) + side_y] for i in range(n_y + 1)] 300 | 301 | pred = np.zeros([n_samples, n_classes, h, w]) 302 | count = np.zeros([h, w]) 303 | 304 | slice_count = 0 305 | for sx, ex in x_ends: 306 | for sy, ey in y_ends: 307 | slice_count += 1 308 | 309 | imgs_slice = imgs[:, :, sx:ex, sy:ey] 310 | if include_flip_mode: 311 | imgs_slice_flip = torch.from_numpy( 312 | np.copy(imgs_slice.cpu().numpy()[:, :, :, ::-1]) 313 | ).float() 314 | 315 | is_model_on_cuda = next(self.parameters()).is_cuda 316 | 317 | inp = Variable(imgs_slice, volatile=True) 318 | if include_flip_mode: 319 | flp = Variable(imgs_slice_flip, volatile=True) 320 | 321 | if is_model_on_cuda: 322 | inp = inp.cuda() 323 | if include_flip_mode: 324 | flp = flp.cuda() 325 | 326 | psub1 = F.softmax(self.forward(inp), dim=1).data.cpu().numpy() 327 | if include_flip_mode: 328 | psub2 = F.softmax(self.forward(flp), dim=1).data.cpu().numpy() 329 | psub = (psub1 + psub2[:, :, :, ::-1]) / 2.0 330 | else: 331 | psub = psub1 332 | 333 | pred[:, :, sx:ex, sy:ey] = psub 334 | count[sx:ex, sy:ey] += 1.0 335 | 336 | score = (pred / count[None, None, ...]).astype(np.float32) 337 | return score / np.expand_dims(score.sum(axis=1), axis=1) 338 | 339 | 340 | # For Testing Purposes only 341 | if __name__ == "__main__": 342 | cd = 0 343 | import os 344 | import scipy.misc as m 345 | from ptsemseg.loader.cityscapes_loader import cityscapesLoader as cl 346 | 347 | psp = pspnet(version="cityscapes") 348 | 349 | # Just need to do this one time 350 | caffemodel_dir_path = "PATH_TO_PSPNET_DIR/evaluation/model" 351 | psp.load_pretrained_model( 352 | model_path=os.path.join(caffemodel_dir_path, "pspnet101_cityscapes.caffemodel") 353 | ) 354 | # psp.load_pretrained_model(model_path=os.path.join(caffemodel_dir_path, 355 | # 'pspnet50_ADE20K.caffemodel')) 356 | # psp.load_pretrained_model(model_path=os.path.join(caffemodel_dir_path, 357 | # 'pspnet101_VOC2012.caffemodel')) 358 | # 359 | # psp.load_state_dict(torch.load('psp.pth')) 360 | 361 | psp.float() 362 | psp.cuda(cd) 363 | psp.eval() 364 | 365 | dataset_root_dir = "PATH_TO_CITYSCAPES_DIR" 366 | dst = cl(root=dataset_root_dir) 367 | img = m.imread( 368 | os.path.join( 369 | dataset_root_dir, 370 | "leftImg8bit/demoVideo/stuttgart_00/stuttgart_00_000000_000010_leftImg8bit.png", 371 | ) 372 | ) 373 | m.imsave("cropped.png", img) 374 | orig_size = img.shape[:-1] 375 | img = img.transpose(2, 0, 1) 376 | img = img.astype(np.float64) 377 | img -= np.array([123.68, 116.779, 103.939])[:, None, None] 378 | img = np.copy(img[::-1, :, :]) 379 | img = torch.from_numpy(img).float() # convert to torch tensor 380 | img = img.unsqueeze(0) 381 | 382 | out = psp.tile_predict(img) 383 | pred = np.argmax(out, axis=1)[0] 384 | 
decoded = dst.decode_segmap(pred) 385 | m.imsave("cityscapes_sttutgart_tiled.png", decoded) 386 | # m.imsave('cityscapes_sttutgart_tiled.png', pred) 387 | 388 | checkpoints_dir_path = "checkpoints" 389 | if not os.path.exists(checkpoints_dir_path): 390 | os.mkdir(checkpoints_dir_path) 391 | psp = torch.nn.DataParallel( 392 | psp, device_ids=range(torch.cuda.device_count()) 393 | ) # append `module.` 394 | state = {"model_state": psp.state_dict()} 395 | torch.save(state, os.path.join(checkpoints_dir_path, "pspnet_101_cityscapes.pth")) 396 | # torch.save(state, os.path.join(checkpoints_dir_path, "pspnet_50_ade20k.pth")) 397 | # torch.save(state, os.path.join(checkpoints_dir_path, "pspnet_101_pascalvoc.pth")) 398 | print("Output Shape {} \t Input Shape {}".format(out.shape, img.shape)) 399 | -------------------------------------------------------------------------------- /ptsemseg/models/refinenet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class refinenet(nn.Module): 5 | """ 6 | RefineNet: Multi-Path Refinement Networks for High-Resolution Semantic Segmentation 7 | URL: https://arxiv.org/abs/1611.06612 8 | 9 | References: 10 | 1) Original Author's MATLAB code: https://github.com/guosheng/refinenet 11 | 2) TF implementation by @eragonruan: https://github.com/eragonruan/refinenet-image-segmentation 12 | """ 13 | 14 | def __init__(self, n_classes=21): 15 | super(refinenet, self).__init__() 16 | self.n_classes = n_classes 17 | 18 | def forward(self, x): 19 | pass 20 | -------------------------------------------------------------------------------- /ptsemseg/models/segnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from ptsemseg.models.utils import segnetDown2, segnetDown3, segnetUp2, segnetUp3 4 | 5 | 6 | class segnet(nn.Module): 7 | def __init__(self, n_classes=21, in_channels=3, is_unpooling=True): 8 | super(segnet, self).__init__() 9 | 10 | self.in_channels = in_channels 11 | self.is_unpooling = is_unpooling 12 | 13 | self.down1 = segnetDown2(self.in_channels, 64) 14 | self.down2 = segnetDown2(64, 128) 15 | self.down3 = segnetDown3(128, 256) 16 | self.down4 = segnetDown3(256, 512) 17 | self.down5 = segnetDown3(512, 512) 18 | 19 | self.up5 = segnetUp3(512, 512) 20 | self.up4 = segnetUp3(512, 256) 21 | self.up3 = segnetUp3(256, 128) 22 | self.up2 = segnetUp2(128, 64) 23 | self.up1 = segnetUp2(64, n_classes) 24 | 25 | def forward(self, inputs): 26 | 27 | down1, indices_1, unpool_shape1 = self.down1(inputs) 28 | down2, indices_2, unpool_shape2 = self.down2(down1) 29 | down3, indices_3, unpool_shape3 = self.down3(down2) 30 | down4, indices_4, unpool_shape4 = self.down4(down3) 31 | down5, indices_5, unpool_shape5 = self.down5(down4) 32 | 33 | up5 = self.up5(down5, indices_5, unpool_shape5) 34 | up4 = self.up4(up5, indices_4, unpool_shape4) 35 | up3 = self.up3(up4, indices_3, unpool_shape3) 36 | up2 = self.up2(up3, indices_2, unpool_shape2) 37 | up1 = self.up1(up2, indices_1, unpool_shape1) 38 | 39 | return up1 40 | 41 | def init_vgg16_params(self, vgg16): 42 | blocks = [self.down1, self.down2, self.down3, self.down4, self.down5] 43 | 44 | features = list(vgg16.features.children()) 45 | 46 | vgg_layers = [] 47 | for _layer in features: 48 | if isinstance(_layer, nn.Conv2d): 49 | vgg_layers.append(_layer) 50 | 51 | merged_layers = [] 52 | for idx, conv_block in enumerate(blocks): 53 | if idx < 2: 54 | units = [conv_block.conv1.cbr_unit, 
conv_block.conv2.cbr_unit] 55 | else: 56 | units = [ 57 | conv_block.conv1.cbr_unit, 58 | conv_block.conv2.cbr_unit, 59 | conv_block.conv3.cbr_unit, 60 | ] 61 | for _unit in units: 62 | for _layer in _unit: 63 | if isinstance(_layer, nn.Conv2d): 64 | merged_layers.append(_layer) 65 | 66 | assert len(vgg_layers) == len(merged_layers) 67 | 68 | for l1, l2 in zip(vgg_layers, merged_layers): 69 | if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d): 70 | assert l1.weight.size() == l2.weight.size() 71 | assert l1.bias.size() == l2.bias.size() 72 | l2.weight.data = l1.weight.data 73 | l2.bias.data = l1.bias.data 74 | -------------------------------------------------------------------------------- /ptsemseg/models/unet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from ptsemseg.models.utils import unetConv2, unetUp 4 | 5 | 6 | class unet(nn.Module): 7 | def __init__( 8 | self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3, is_batchnorm=True 9 | ): 10 | super(unet, self).__init__() 11 | self.is_deconv = is_deconv 12 | self.in_channels = in_channels 13 | self.is_batchnorm = is_batchnorm 14 | self.feature_scale = feature_scale 15 | 16 | filters = [64, 128, 256, 512, 1024] 17 | filters = [int(x / self.feature_scale) for x in filters] 18 | 19 | # downsampling 20 | self.conv1 = unetConv2(self.in_channels, filters[0], self.is_batchnorm) 21 | self.maxpool1 = nn.MaxPool2d(kernel_size=2) 22 | 23 | self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm) 24 | self.maxpool2 = nn.MaxPool2d(kernel_size=2) 25 | 26 | self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm) 27 | self.maxpool3 = nn.MaxPool2d(kernel_size=2) 28 | 29 | self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm) 30 | self.maxpool4 = nn.MaxPool2d(kernel_size=2) 31 | 32 | self.center = unetConv2(filters[3], filters[4], self.is_batchnorm) 33 | 34 | # upsampling 35 | self.up_concat4 = unetUp(filters[4], filters[3], self.is_deconv) 36 | self.up_concat3 = unetUp(filters[3], filters[2], self.is_deconv) 37 | self.up_concat2 = unetUp(filters[2], filters[1], self.is_deconv) 38 | self.up_concat1 = unetUp(filters[1], filters[0], self.is_deconv) 39 | 40 | # final conv (without any concat) 41 | self.final = nn.Conv2d(filters[0], n_classes, 1) 42 | 43 | def forward(self, inputs): 44 | conv1 = self.conv1(inputs) 45 | maxpool1 = self.maxpool1(conv1) 46 | 47 | conv2 = self.conv2(maxpool1) 48 | maxpool2 = self.maxpool2(conv2) 49 | 50 | conv3 = self.conv3(maxpool2) 51 | maxpool3 = self.maxpool3(conv3) 52 | 53 | conv4 = self.conv4(maxpool3) 54 | maxpool4 = self.maxpool4(conv4) 55 | 56 | center = self.center(maxpool4) 57 | up4 = self.up_concat4(conv4, center) 58 | up3 = self.up_concat3(conv3, up4) 59 | up2 = self.up_concat2(conv2, up3) 60 | up1 = self.up_concat1(conv1, up2) 61 | 62 | final = self.final(up1) 63 | 64 | return final 65 | -------------------------------------------------------------------------------- /ptsemseg/models/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import torch.nn.functional as F 5 | 6 | from torch.autograd import Variable 7 | 8 | 9 | class conv2DBatchNorm(nn.Module): 10 | def __init__( 11 | self, 12 | in_channels, 13 | n_filters, 14 | k_size, 15 | stride, 16 | padding, 17 | bias=True, 18 | dilation=1, 19 | is_batchnorm=True, 20 | ): 21 | super(conv2DBatchNorm, self).__init__() 22 | 23 | conv_mod = 
nn.Conv2d( 24 | int(in_channels), 25 | int(n_filters), 26 | kernel_size=k_size, 27 | padding=padding, 28 | stride=stride, 29 | bias=bias, 30 | dilation=dilation, 31 | ) 32 | 33 | if is_batchnorm: 34 | self.cb_unit = nn.Sequential(conv_mod, nn.BatchNorm2d(int(n_filters))) 35 | else: 36 | self.cb_unit = nn.Sequential(conv_mod) 37 | 38 | def forward(self, inputs): 39 | outputs = self.cb_unit(inputs) 40 | return outputs 41 | 42 | 43 | class conv2DGroupNorm(nn.Module): 44 | def __init__( 45 | self, in_channels, n_filters, k_size, stride, padding, bias=True, dilation=1, n_groups=16 46 | ): 47 | super(conv2DGroupNorm, self).__init__() 48 | 49 | conv_mod = nn.Conv2d( 50 | int(in_channels), 51 | int(n_filters), 52 | kernel_size=k_size, 53 | padding=padding, 54 | stride=stride, 55 | bias=bias, 56 | dilation=dilation, 57 | ) 58 | 59 | self.cg_unit = nn.Sequential(conv_mod, nn.GroupNorm(n_groups, int(n_filters))) 60 | 61 | def forward(self, inputs): 62 | outputs = self.cg_unit(inputs) 63 | return outputs 64 | 65 | 66 | class deconv2DBatchNorm(nn.Module): 67 | def __init__(self, in_channels, n_filters, k_size, stride, padding, bias=True): 68 | super(deconv2DBatchNorm, self).__init__() 69 | 70 | self.dcb_unit = nn.Sequential( 71 | nn.ConvTranspose2d( 72 | int(in_channels), 73 | int(n_filters), 74 | kernel_size=k_size, 75 | padding=padding, 76 | stride=stride, 77 | bias=bias, 78 | ), 79 | nn.BatchNorm2d(int(n_filters)), 80 | ) 81 | 82 | def forward(self, inputs): 83 | outputs = self.dcb_unit(inputs) 84 | return outputs 85 | 86 | 87 | class conv2DBatchNormRelu(nn.Module): 88 | def __init__( 89 | self, 90 | in_channels, 91 | n_filters, 92 | k_size, 93 | stride, 94 | padding, 95 | bias=True, 96 | dilation=1, 97 | is_batchnorm=True, 98 | ): 99 | super(conv2DBatchNormRelu, self).__init__() 100 | 101 | conv_mod = nn.Conv2d( 102 | int(in_channels), 103 | int(n_filters), 104 | kernel_size=k_size, 105 | padding=padding, 106 | stride=stride, 107 | bias=bias, 108 | dilation=dilation, 109 | ) 110 | 111 | if is_batchnorm: 112 | self.cbr_unit = nn.Sequential( 113 | conv_mod, nn.BatchNorm2d(int(n_filters)), nn.ReLU(inplace=True) 114 | ) 115 | else: 116 | self.cbr_unit = nn.Sequential(conv_mod, nn.ReLU(inplace=True)) 117 | 118 | def forward(self, inputs): 119 | outputs = self.cbr_unit(inputs) 120 | return outputs 121 | 122 | 123 | class conv2DGroupNormRelu(nn.Module): 124 | def __init__( 125 | self, in_channels, n_filters, k_size, stride, padding, bias=True, dilation=1, n_groups=16 126 | ): 127 | super(conv2DGroupNormRelu, self).__init__() 128 | 129 | conv_mod = nn.Conv2d( 130 | int(in_channels), 131 | int(n_filters), 132 | kernel_size=k_size, 133 | padding=padding, 134 | stride=stride, 135 | bias=bias, 136 | dilation=dilation, 137 | ) 138 | 139 | self.cgr_unit = nn.Sequential( 140 | conv_mod, nn.GroupNorm(n_groups, int(n_filters)), nn.ReLU(inplace=True) 141 | ) 142 | 143 | def forward(self, inputs): 144 | outputs = self.cgr_unit(inputs) 145 | return outputs 146 | 147 | 148 | class deconv2DBatchNormRelu(nn.Module): 149 | def __init__(self, in_channels, n_filters, k_size, stride, padding, bias=True): 150 | super(deconv2DBatchNormRelu, self).__init__() 151 | 152 | self.dcbr_unit = nn.Sequential( 153 | nn.ConvTranspose2d( 154 | int(in_channels), 155 | int(n_filters), 156 | kernel_size=k_size, 157 | padding=padding, 158 | stride=stride, 159 | bias=bias, 160 | ), 161 | nn.BatchNorm2d(int(n_filters)), 162 | nn.ReLU(inplace=True), 163 | ) 164 | 165 | def forward(self, inputs): 166 | outputs = self.dcbr_unit(inputs) 167 | 
return outputs 168 | 169 | 170 | class unetConv2(nn.Module): 171 | def __init__(self, in_size, out_size, is_batchnorm): 172 | super(unetConv2, self).__init__() 173 | 174 | if is_batchnorm: 175 | self.conv1 = nn.Sequential( 176 | nn.Conv2d(in_size, out_size, 3, 1, 0), nn.BatchNorm2d(out_size), nn.ReLU() 177 | ) 178 | self.conv2 = nn.Sequential( 179 | nn.Conv2d(out_size, out_size, 3, 1, 0), nn.BatchNorm2d(out_size), nn.ReLU() 180 | ) 181 | else: 182 | self.conv1 = nn.Sequential(nn.Conv2d(in_size, out_size, 3, 1, 0), nn.ReLU()) 183 | self.conv2 = nn.Sequential(nn.Conv2d(out_size, out_size, 3, 1, 0), nn.ReLU()) 184 | 185 | def forward(self, inputs): 186 | outputs = self.conv1(inputs) 187 | outputs = self.conv2(outputs) 188 | return outputs 189 | 190 | 191 | class unetUp(nn.Module): 192 | def __init__(self, in_size, out_size, is_deconv): 193 | super(unetUp, self).__init__() 194 | self.conv = unetConv2(in_size, out_size, False) 195 | if is_deconv: 196 | self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2) 197 | else: 198 | self.up = nn.UpsamplingBilinear2d(scale_factor=2) 199 | 200 | def forward(self, inputs1, inputs2): 201 | outputs2 = self.up(inputs2) 202 | offset = outputs2.size()[2] - inputs1.size()[2] 203 | padding = 2 * [offset // 2, offset // 2] 204 | outputs1 = F.pad(inputs1, padding) 205 | return self.conv(torch.cat([outputs1, outputs2], 1)) 206 | 207 | 208 | class segnetDown2(nn.Module): 209 | def __init__(self, in_size, out_size): 210 | super(segnetDown2, self).__init__() 211 | self.conv1 = conv2DBatchNormRelu(in_size, out_size, 3, 1, 1) 212 | self.conv2 = conv2DBatchNormRelu(out_size, out_size, 3, 1, 1) 213 | self.maxpool_with_argmax = nn.MaxPool2d(2, 2, return_indices=True) 214 | 215 | def forward(self, inputs): 216 | outputs = self.conv1(inputs) 217 | outputs = self.conv2(outputs) 218 | unpooled_shape = outputs.size() 219 | outputs, indices = self.maxpool_with_argmax(outputs) 220 | return outputs, indices, unpooled_shape 221 | 222 | 223 | class segnetDown3(nn.Module): 224 | def __init__(self, in_size, out_size): 225 | super(segnetDown3, self).__init__() 226 | self.conv1 = conv2DBatchNormRelu(in_size, out_size, 3, 1, 1) 227 | self.conv2 = conv2DBatchNormRelu(out_size, out_size, 3, 1, 1) 228 | self.conv3 = conv2DBatchNormRelu(out_size, out_size, 3, 1, 1) 229 | self.maxpool_with_argmax = nn.MaxPool2d(2, 2, return_indices=True) 230 | 231 | def forward(self, inputs): 232 | outputs = self.conv1(inputs) 233 | outputs = self.conv2(outputs) 234 | outputs = self.conv3(outputs) 235 | unpooled_shape = outputs.size() 236 | outputs, indices = self.maxpool_with_argmax(outputs) 237 | return outputs, indices, unpooled_shape 238 | 239 | 240 | class segnetUp2(nn.Module): 241 | def __init__(self, in_size, out_size): 242 | super(segnetUp2, self).__init__() 243 | self.unpool = nn.MaxUnpool2d(2, 2) 244 | self.conv1 = conv2DBatchNormRelu(in_size, in_size, 3, 1, 1) 245 | self.conv2 = conv2DBatchNormRelu(in_size, out_size, 3, 1, 1) 246 | 247 | def forward(self, inputs, indices, output_shape): 248 | outputs = self.unpool(input=inputs, indices=indices, output_size=output_shape) 249 | outputs = self.conv1(outputs) 250 | outputs = self.conv2(outputs) 251 | return outputs 252 | 253 | 254 | class segnetUp3(nn.Module): 255 | def __init__(self, in_size, out_size): 256 | super(segnetUp3, self).__init__() 257 | self.unpool = nn.MaxUnpool2d(2, 2) 258 | self.conv1 = conv2DBatchNormRelu(in_size, in_size, 3, 1, 1) 259 | self.conv2 = conv2DBatchNormRelu(in_size, in_size, 3, 1, 1) 260 | self.conv3 = 
conv2DBatchNormRelu(in_size, out_size, 3, 1, 1)
261 | 
262 |     def forward(self, inputs, indices, output_shape):
263 |         outputs = self.unpool(input=inputs, indices=indices, output_size=output_shape)
264 |         outputs = self.conv1(outputs)
265 |         outputs = self.conv2(outputs)
266 |         outputs = self.conv3(outputs)
267 |         return outputs
268 | 
269 | 
270 | class residualBlock(nn.Module):
271 |     expansion = 1
272 | 
273 |     def __init__(self, in_channels, n_filters, stride=1, downsample=None):
274 |         super(residualBlock, self).__init__()
275 | 
276 |         self.convbnrelu1 = conv2DBatchNormRelu(in_channels, n_filters, 3, stride, 1, bias=False)
277 |         self.convbn2 = conv2DBatchNorm(n_filters, n_filters, 3, 1, 1, bias=False)
278 |         self.downsample = downsample
279 |         self.stride = stride
280 |         self.relu = nn.ReLU(inplace=True)
281 | 
282 |     def forward(self, x):
283 |         residual = x
284 | 
285 |         out = self.convbnrelu1(x)
286 |         out = self.convbn2(out)
287 | 
288 |         if self.downsample is not None:
289 |             residual = self.downsample(x)
290 | 
291 |         out += residual
292 |         out = self.relu(out)
293 |         return out
294 | 
295 | 
296 | class residualBottleneck(nn.Module):
297 |     expansion = 4
298 | 
299 |     def __init__(self, in_channels, n_filters, stride=1, downsample=None):
300 |         super(residualBottleneck, self).__init__()
301 |         self.convbn1 = conv2DBatchNorm(in_channels, n_filters, k_size=1, stride=1, padding=0, bias=False)
302 |         self.convbn2 = conv2DBatchNorm(
303 |             n_filters, n_filters, k_size=3, padding=1, stride=stride, bias=False
304 |         )
305 |         self.convbn3 = conv2DBatchNorm(n_filters, n_filters * 4, k_size=1, stride=1, padding=0, bias=False)
306 |         self.relu = nn.ReLU(inplace=True)
307 |         self.downsample = downsample
308 |         self.stride = stride
309 | 
310 |     def forward(self, x):
311 |         residual = x
312 | 
313 |         out = self.convbn1(x)
314 |         out = self.convbn2(out)
315 |         out = self.convbn3(out)
316 | 
317 |         if self.downsample is not None:
318 |             residual = self.downsample(x)
319 | 
320 |         out += residual
321 |         out = self.relu(out)
322 | 
323 |         return out
324 | 
325 | 
326 | class linknetUp(nn.Module):
327 |     def __init__(self, in_channels, n_filters):
328 |         super(linknetUp, self).__init__()
329 | 
330 |         # B, 2C, H, W -> B, C/2, H, W
331 |         self.convbnrelu1 = conv2DBatchNormRelu(
332 |             in_channels, int(n_filters / 2), k_size=1, stride=1, padding=1
333 |         )
334 | 
335 |         # B, C/2, H, W -> B, C/2, H, W
336 |         self.deconvbnrelu2 = deconv2DBatchNormRelu(
337 |             int(n_filters / 2), int(n_filters / 2), k_size=3, stride=2, padding=0
338 |         )
339 | 
340 |         # B, C/2, H, W -> B, C, H, W
341 |         self.convbnrelu3 = conv2DBatchNormRelu(
342 |             int(n_filters / 2), n_filters, k_size=1, stride=1, padding=1
343 |         )
344 | 
345 |     def forward(self, x):
346 |         x = self.convbnrelu1(x)
347 |         x = self.deconvbnrelu2(x)
348 |         x = self.convbnrelu3(x)
349 |         return x
350 | 
351 | 
352 | class FRRU(nn.Module):
353 |     """
354 |     Full Resolution Residual Unit for FRRN
355 |     """
356 | 
357 |     def __init__(self, prev_channels, out_channels, scale, group_norm=False, n_groups=None):
358 |         super(FRRU, self).__init__()
359 |         self.scale = scale
360 |         self.prev_channels = prev_channels
361 |         self.out_channels = out_channels
362 |         self.group_norm = group_norm
363 |         self.n_groups = n_groups
364 | 
365 |         if self.group_norm:
366 |             conv_unit = conv2DGroupNormRelu
367 |             self.conv1 = conv_unit(
368 |                 prev_channels + 32,
369 |                 out_channels,
370 |                 k_size=3,
371 |                 stride=1,
372 |                 padding=1,
373 |                 bias=False,
374 |                 n_groups=self.n_groups,
375 |             )
376 |             self.conv2 = conv_unit(
377 |                 out_channels,
378 |                 out_channels,
379 |                 k_size=3,
380 |                 stride=1,
381 |                 padding=1,
382 |                 bias=False, 
383 | n_groups=self.n_groups, 384 | ) 385 | 386 | else: 387 | conv_unit = conv2DBatchNormRelu 388 | self.conv1 = conv_unit( 389 | prev_channels + 32, out_channels, k_size=3, stride=1, padding=1, bias=False 390 | ) 391 | self.conv2 = conv_unit( 392 | out_channels, out_channels, k_size=3, stride=1, padding=1, bias=False 393 | ) 394 | 395 | self.conv_res = nn.Conv2d(out_channels, 32, kernel_size=1, stride=1, padding=0) 396 | 397 | def forward(self, y, z): 398 | x = torch.cat([y, nn.MaxPool2d(self.scale, self.scale)(z)], dim=1) 399 | y_prime = self.conv1(x) 400 | y_prime = self.conv2(y_prime) 401 | 402 | x = self.conv_res(y_prime) 403 | upsample_size = torch.Size([_s * self.scale for _s in y_prime.shape[-2:]]) 404 | x = F.upsample(x, size=upsample_size, mode="nearest") 405 | z_prime = z + x 406 | 407 | return y_prime, z_prime 408 | 409 | 410 | class RU(nn.Module): 411 | """ 412 | Residual Unit for FRRN 413 | """ 414 | 415 | def __init__(self, channels, kernel_size=3, strides=1, group_norm=False, n_groups=None): 416 | super(RU, self).__init__() 417 | self.group_norm = group_norm 418 | self.n_groups = n_groups 419 | 420 | if self.group_norm: 421 | self.conv1 = conv2DGroupNormRelu( 422 | channels, 423 | channels, 424 | k_size=kernel_size, 425 | stride=strides, 426 | padding=1, 427 | bias=False, 428 | n_groups=self.n_groups, 429 | ) 430 | self.conv2 = conv2DGroupNorm( 431 | channels, 432 | channels, 433 | k_size=kernel_size, 434 | stride=strides, 435 | padding=1, 436 | bias=False, 437 | n_groups=self.n_groups, 438 | ) 439 | 440 | else: 441 | self.conv1 = conv2DBatchNormRelu( 442 | channels, channels, k_size=kernel_size, stride=strides, padding=1, bias=False 443 | ) 444 | self.conv2 = conv2DBatchNorm( 445 | channels, channels, k_size=kernel_size, stride=strides, padding=1, bias=False 446 | ) 447 | 448 | def forward(self, x): 449 | incoming = x 450 | x = self.conv1(x) 451 | x = self.conv2(x) 452 | return x + incoming 453 | 454 | 455 | class residualConvUnit(nn.Module): 456 | def __init__(self, channels, kernel_size=3): 457 | super(residualConvUnit, self).__init__() 458 | 459 | self.residual_conv_unit = nn.Sequential( 460 | nn.ReLU(inplace=True), 461 | nn.Conv2d(channels, channels, kernel_size=kernel_size), 462 | nn.ReLU(inplace=True), 463 | nn.Conv2d(channels, channels, kernel_size=kernel_size), 464 | ) 465 | 466 | def forward(self, x): 467 | input = x 468 | x = self.residual_conv_unit(x) 469 | return x + input 470 | 471 | 472 | class multiResolutionFusion(nn.Module): 473 | def __init__(self, channels, up_scale_high, up_scale_low, high_shape, low_shape): 474 | super(multiResolutionFusion, self).__init__() 475 | 476 | self.up_scale_high = up_scale_high 477 | self.up_scale_low = up_scale_low 478 | 479 | self.conv_high = nn.Conv2d(high_shape[1], channels, kernel_size=3) 480 | 481 | if low_shape is not None: 482 | self.conv_low = nn.Conv2d(low_shape[1], channels, kernel_size=3) 483 | 484 | def forward(self, x_high, x_low): 485 | high_upsampled = F.upsample( 486 | self.conv_high(x_high), scale_factor=self.up_scale_high, mode="bilinear" 487 | ) 488 | 489 | if x_low is None: 490 | return high_upsampled 491 | 492 | low_upsampled = F.upsample( 493 | self.conv_low(x_low), scale_factor=self.up_scale_low, mode="bilinear" 494 | ) 495 | 496 | return low_upsampled + high_upsampled 497 | 498 | 499 | class chainedResidualPooling(nn.Module): 500 | def __init__(self, channels, input_shape): 501 | super(chainedResidualPooling, self).__init__() 502 | 503 | self.chained_residual_pooling = nn.Sequential( 504 | 
nn.ReLU(inplace=True), 505 | nn.MaxPool2d(5, 1, 2), 506 | nn.Conv2d(input_shape[1], channels, kernel_size=3), 507 | ) 508 | 509 | def forward(self, x): 510 | input = x 511 | x = self.chained_residual_pooling(x) 512 | return x + input 513 | 514 | 515 | class pyramidPooling(nn.Module): 516 | def __init__( 517 | self, in_channels, pool_sizes, model_name="pspnet", fusion_mode="cat", is_batchnorm=True 518 | ): 519 | super(pyramidPooling, self).__init__() 520 | 521 | bias = not is_batchnorm 522 | 523 | self.paths = [] 524 | for i in range(len(pool_sizes)): 525 | self.paths.append( 526 | conv2DBatchNormRelu( 527 | in_channels, 528 | int(in_channels / len(pool_sizes)), 529 | 1, 530 | 1, 531 | 0, 532 | bias=bias, 533 | is_batchnorm=is_batchnorm, 534 | ) 535 | ) 536 | 537 | self.path_module_list = nn.ModuleList(self.paths) 538 | self.pool_sizes = pool_sizes 539 | self.model_name = model_name 540 | self.fusion_mode = fusion_mode 541 | 542 | def forward(self, x): 543 | h, w = x.shape[2:] 544 | 545 | if self.training or self.model_name != "icnet": # general settings or pspnet 546 | k_sizes = [] 547 | strides = [] 548 | for pool_size in self.pool_sizes: 549 | k_sizes.append((int(h / pool_size), int(w / pool_size))) 550 | strides.append((int(h / pool_size), int(w / pool_size))) 551 | else: # eval mode and icnet: pre-trained for 1025 x 2049 552 | k_sizes = [(8, 15), (13, 25), (17, 33), (33, 65)] 553 | strides = [(5, 10), (10, 20), (16, 32), (33, 65)] 554 | 555 | if self.fusion_mode == "cat": # pspnet: concat (including x) 556 | output_slices = [x] 557 | 558 | for i, (module, pool_size) in enumerate(zip(self.path_module_list, self.pool_sizes)): 559 | out = F.avg_pool2d(x, k_sizes[i], stride=strides[i], padding=0) 560 | # out = F.adaptive_avg_pool2d(x, output_size=(pool_size, pool_size)) 561 | if self.model_name != "icnet": 562 | out = module(out) 563 | out = F.interpolate(out, size=(h, w), mode="bilinear", align_corners=True) 564 | output_slices.append(out) 565 | 566 | return torch.cat(output_slices, dim=1) 567 | else: # icnet: element-wise sum (including x) 568 | pp_sum = x 569 | 570 | for i, (module, pool_size) in enumerate(zip(self.path_module_list, self.pool_sizes)): 571 | out = F.avg_pool2d(x, k_sizes[i], stride=strides[i], padding=0) 572 | # out = F.adaptive_avg_pool2d(x, output_size=(pool_size, pool_size)) 573 | if self.model_name != "icnet": 574 | out = module(out) 575 | out = F.interpolate(out, size=(h, w), mode="bilinear", align_corners=True) 576 | pp_sum = pp_sum + out 577 | 578 | return pp_sum 579 | 580 | 581 | class bottleNeckPSP(nn.Module): 582 | def __init__( 583 | self, in_channels, mid_channels, out_channels, stride, dilation=1, is_batchnorm=True 584 | ): 585 | super(bottleNeckPSP, self).__init__() 586 | 587 | bias = not is_batchnorm 588 | 589 | self.cbr1 = conv2DBatchNormRelu( 590 | in_channels, mid_channels, 1, stride=1, padding=0, bias=bias, is_batchnorm=is_batchnorm 591 | ) 592 | if dilation > 1: 593 | self.cbr2 = conv2DBatchNormRelu( 594 | mid_channels, 595 | mid_channels, 596 | 3, 597 | stride=stride, 598 | padding=dilation, 599 | bias=bias, 600 | dilation=dilation, 601 | is_batchnorm=is_batchnorm, 602 | ) 603 | else: 604 | self.cbr2 = conv2DBatchNormRelu( 605 | mid_channels, 606 | mid_channels, 607 | 3, 608 | stride=stride, 609 | padding=1, 610 | bias=bias, 611 | dilation=1, 612 | is_batchnorm=is_batchnorm, 613 | ) 614 | self.cb3 = conv2DBatchNorm( 615 | mid_channels, out_channels, 1, stride=1, padding=0, bias=bias, is_batchnorm=is_batchnorm 616 | ) 617 | self.cb4 = conv2DBatchNorm( 
618 | in_channels, 619 | out_channels, 620 | 1, 621 | stride=stride, 622 | padding=0, 623 | bias=bias, 624 | is_batchnorm=is_batchnorm, 625 | ) 626 | 627 | def forward(self, x): 628 | conv = self.cb3(self.cbr2(self.cbr1(x))) 629 | residual = self.cb4(x) 630 | return F.relu(conv + residual, inplace=True) 631 | 632 | 633 | class bottleNeckIdentifyPSP(nn.Module): 634 | def __init__(self, in_channels, mid_channels, stride, dilation=1, is_batchnorm=True): 635 | super(bottleNeckIdentifyPSP, self).__init__() 636 | 637 | bias = not is_batchnorm 638 | 639 | self.cbr1 = conv2DBatchNormRelu( 640 | in_channels, mid_channels, 1, stride=1, padding=0, bias=bias, is_batchnorm=is_batchnorm 641 | ) 642 | if dilation > 1: 643 | self.cbr2 = conv2DBatchNormRelu( 644 | mid_channels, 645 | mid_channels, 646 | 3, 647 | stride=1, 648 | padding=dilation, 649 | bias=bias, 650 | dilation=dilation, 651 | is_batchnorm=is_batchnorm, 652 | ) 653 | else: 654 | self.cbr2 = conv2DBatchNormRelu( 655 | mid_channels, 656 | mid_channels, 657 | 3, 658 | stride=1, 659 | padding=1, 660 | bias=bias, 661 | dilation=1, 662 | is_batchnorm=is_batchnorm, 663 | ) 664 | self.cb3 = conv2DBatchNorm( 665 | mid_channels, in_channels, 1, stride=1, padding=0, bias=bias, is_batchnorm=is_batchnorm 666 | ) 667 | 668 | def forward(self, x): 669 | residual = x 670 | x = self.cb3(self.cbr2(self.cbr1(x))) 671 | return F.relu(x + residual, inplace=True) 672 | 673 | 674 | class residualBlockPSP(nn.Module): 675 | def __init__( 676 | self, 677 | n_blocks, 678 | in_channels, 679 | mid_channels, 680 | out_channels, 681 | stride, 682 | dilation=1, 683 | include_range="all", 684 | is_batchnorm=True, 685 | ): 686 | super(residualBlockPSP, self).__init__() 687 | 688 | if dilation > 1: 689 | stride = 1 690 | 691 | # residualBlockPSP = convBlockPSP + identityBlockPSPs 692 | layers = [] 693 | if include_range in ["all", "conv"]: 694 | layers.append( 695 | bottleNeckPSP( 696 | in_channels, 697 | mid_channels, 698 | out_channels, 699 | stride, 700 | dilation, 701 | is_batchnorm=is_batchnorm, 702 | ) 703 | ) 704 | if include_range in ["all", "identity"]: 705 | for i in range(n_blocks - 1): 706 | layers.append( 707 | bottleNeckIdentifyPSP( 708 | out_channels, mid_channels, stride, dilation, is_batchnorm=is_batchnorm 709 | ) 710 | ) 711 | 712 | self.layers = nn.Sequential(*layers) 713 | 714 | def forward(self, x): 715 | return self.layers(x) 716 | 717 | 718 | class cascadeFeatureFusion(nn.Module): 719 | def __init__( 720 | self, n_classes, low_in_channels, high_in_channels, out_channels, is_batchnorm=True 721 | ): 722 | super(cascadeFeatureFusion, self).__init__() 723 | 724 | bias = not is_batchnorm 725 | 726 | self.low_dilated_conv_bn = conv2DBatchNorm( 727 | low_in_channels, 728 | out_channels, 729 | 3, 730 | stride=1, 731 | padding=2, 732 | bias=bias, 733 | dilation=2, 734 | is_batchnorm=is_batchnorm, 735 | ) 736 | self.low_classifier_conv = nn.Conv2d( 737 | int(low_in_channels), 738 | int(n_classes), 739 | kernel_size=1, 740 | padding=0, 741 | stride=1, 742 | bias=True, 743 | dilation=1, 744 | ) # Train only 745 | self.high_proj_conv_bn = conv2DBatchNorm( 746 | high_in_channels, 747 | out_channels, 748 | 1, 749 | stride=1, 750 | padding=0, 751 | bias=bias, 752 | is_batchnorm=is_batchnorm, 753 | ) 754 | 755 | def forward(self, x_low, x_high): 756 | x_low_upsampled = F.interpolate( 757 | x_low, size=get_interp_size(x_low, z_factor=2), mode="bilinear", align_corners=True 758 | ) 759 | 760 | low_cls = self.low_classifier_conv(x_low_upsampled) 761 | 762 | low_fm = 
self.low_dilated_conv_bn(x_low_upsampled) 763 | high_fm = self.high_proj_conv_bn(x_high) 764 | high_fused_fm = F.relu(low_fm + high_fm, inplace=True) 765 | 766 | return high_fused_fm, low_cls 767 | 768 | 769 | def get_interp_size(input, s_factor=1, z_factor=1): # for caffe 770 | ori_h, ori_w = input.shape[2:] 771 | 772 | # shrink (s_factor >= 1) 773 | ori_h = (ori_h - 1) / s_factor + 1 774 | ori_w = (ori_w - 1) / s_factor + 1 775 | 776 | # zoom (z_factor >= 1) 777 | ori_h = ori_h + (ori_h - 1) * (z_factor - 1) 778 | ori_w = ori_w + (ori_w - 1) * (z_factor - 1) 779 | 780 | resize_shape = (int(ori_h), int(ori_w)) 781 | return resize_shape 782 | 783 | 784 | def interp(input, output_size, mode="bilinear"): 785 | n, c, ih, iw = input.shape 786 | oh, ow = output_size 787 | 788 | # normalize to [-1, 1] 789 | h = torch.arange(0, oh, dtype=torch.float, device=input.device) / (oh - 1) * 2 - 1 790 | w = torch.arange(0, ow, dtype=torch.float, device=input.device) / (ow - 1) * 2 - 1 791 | 792 | grid = torch.zeros(oh, ow, 2, dtype=torch.float, device=input.device) 793 | grid[:, :, 0] = w.unsqueeze(0).repeat(oh, 1) 794 | grid[:, :, 1] = h.unsqueeze(0).repeat(ow, 1).transpose(0, 1) 795 | grid = grid.unsqueeze(0).repeat(n, 1, 1, 1) # grid.shape: [n, oh, ow, 2] 796 | grid = Variable(grid) 797 | if input.is_cuda: 798 | grid = grid.cuda() 799 | 800 | return F.grid_sample(input, grid, mode=mode) 801 | 802 | 803 | def get_upsampling_weight(in_channels, out_channels, kernel_size): 804 | """Make a 2D bilinear kernel suitable for upsampling""" 805 | factor = (kernel_size + 1) // 2 806 | if kernel_size % 2 == 1: 807 | center = factor - 1 808 | else: 809 | center = factor - 0.5 810 | og = np.ogrid[:kernel_size, :kernel_size] 811 | filt = (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor) 812 | weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size), dtype=np.float64) 813 | weight[range(in_channels), range(out_channels), :, :] = filt 814 | return torch.from_numpy(weight).float() 815 | -------------------------------------------------------------------------------- /ptsemseg/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from torch.optim import SGD, Adam, ASGD, Adamax, Adadelta, Adagrad, RMSprop 4 | 5 | logger = logging.getLogger("ptsemseg") 6 | 7 | key2opt = { 8 | "sgd": SGD, 9 | "adam": Adam, 10 | "asgd": ASGD, 11 | "adamax": Adamax, 12 | "adadelta": Adadelta, 13 | "adagrad": Adagrad, 14 | "rmsprop": RMSprop, 15 | } 16 | 17 | 18 | def get_optimizer(cfg): 19 | if cfg["training"]["optimizer"] is None: 20 | logger.info("Using SGD optimizer") 21 | return SGD 22 | 23 | else: 24 | opt_name = cfg["training"]["optimizer"]["name"] 25 | if opt_name not in key2opt: 26 | raise NotImplementedError("Optimizer {} not implemented".format(opt_name)) 27 | 28 | logger.info("Using {} optimizer".format(opt_name)) 29 | return key2opt[opt_name] 30 | -------------------------------------------------------------------------------- /ptsemseg/schedulers/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from torch.optim.lr_scheduler import MultiStepLR, ExponentialLR, CosineAnnealingLR 4 | 5 | from ptsemseg.schedulers.schedulers import WarmUpLR, ConstantLR, PolynomialLR 6 | 7 | logger = logging.getLogger("ptsemseg") 8 | 9 | key2scheduler = { 10 | "constant_lr": ConstantLR, 11 | "poly_lr": PolynomialLR, 12 | "multi_step": MultiStepLR, 13 | 
"cosine_annealing": CosineAnnealingLR, 14 | "exp_lr": ExponentialLR, 15 | } 16 | 17 | 18 | def get_scheduler(optimizer, scheduler_dict): 19 | if scheduler_dict is None: 20 | logger.info("Using No LR Scheduling") 21 | return ConstantLR(optimizer) 22 | 23 | s_type = scheduler_dict["name"] 24 | scheduler_dict.pop("name") 25 | 26 | logging.info("Using {} scheduler with {} params".format(s_type, scheduler_dict)) 27 | 28 | warmup_dict = {} 29 | if "warmup_iters" in scheduler_dict: 30 | # This can be done in a more pythonic way... 31 | warmup_dict["warmup_iters"] = scheduler_dict.get("warmup_iters", 100) 32 | warmup_dict["mode"] = scheduler_dict.get("warmup_mode", "linear") 33 | warmup_dict["gamma"] = scheduler_dict.get("warmup_factor", 0.2) 34 | 35 | logger.info( 36 | "Using Warmup with {} iters {} gamma and {} mode".format( 37 | warmup_dict["warmup_iters"], warmup_dict["gamma"], warmup_dict["mode"] 38 | ) 39 | ) 40 | 41 | scheduler_dict.pop("warmup_iters", None) 42 | scheduler_dict.pop("warmup_mode", None) 43 | scheduler_dict.pop("warmup_factor", None) 44 | 45 | base_scheduler = key2scheduler[s_type](optimizer, **scheduler_dict) 46 | return WarmUpLR(optimizer, base_scheduler, **warmup_dict) 47 | 48 | return key2scheduler[s_type](optimizer, **scheduler_dict) 49 | -------------------------------------------------------------------------------- /ptsemseg/schedulers/schedulers.py: -------------------------------------------------------------------------------- 1 | from torch.optim.lr_scheduler import _LRScheduler 2 | 3 | 4 | class ConstantLR(_LRScheduler): 5 | def __init__(self, optimizer, last_epoch=-1): 6 | super(ConstantLR, self).__init__(optimizer, last_epoch) 7 | 8 | def get_lr(self): 9 | return [base_lr for base_lr in self.base_lrs] 10 | 11 | 12 | class PolynomialLR(_LRScheduler): 13 | def __init__(self, optimizer, max_iter, decay_iter=1, gamma=0.9, last_epoch=-1): 14 | self.decay_iter = decay_iter 15 | self.max_iter = max_iter 16 | self.gamma = gamma 17 | super(PolynomialLR, self).__init__(optimizer, last_epoch) 18 | 19 | def get_lr(self): 20 | if self.last_epoch % self.decay_iter or self.last_epoch % self.max_iter: 21 | return [base_lr for base_lr in self.base_lrs] 22 | else: 23 | factor = (1 - self.last_epoch / float(self.max_iter)) ** self.gamma 24 | return [base_lr * factor for base_lr in self.base_lrs] 25 | 26 | 27 | class WarmUpLR(_LRScheduler): 28 | def __init__( 29 | self, optimizer, scheduler, mode="linear", warmup_iters=100, gamma=0.2, last_epoch=-1 30 | ): 31 | self.mode = mode 32 | self.scheduler = scheduler 33 | self.warmup_iters = warmup_iters 34 | self.gamma = gamma 35 | super(WarmUpLR, self).__init__(optimizer, last_epoch) 36 | 37 | def get_lr(self): 38 | cold_lrs = self.scheduler.get_lr() 39 | 40 | if self.last_epoch < self.warmup_iters: 41 | if self.mode == "linear": 42 | alpha = self.last_epoch / float(self.warmup_iters) 43 | factor = self.gamma * (1 - alpha) + alpha 44 | 45 | elif self.mode == "constant": 46 | factor = self.gamma 47 | else: 48 | raise KeyError("WarmUp type {} not implemented".format(self.mode)) 49 | 50 | return [factor * base_lr for base_lr in cold_lrs] 51 | 52 | return cold_lrs 53 | -------------------------------------------------------------------------------- /ptsemseg/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Misc Utility functions 3 | """ 4 | import os 5 | import logging 6 | import datetime 7 | import numpy as np 8 | 9 | from collections import OrderedDict 10 | 11 | 12 | def 
recursive_glob(rootdir=".", suffix=""): 13 | """Performs recursive glob with given suffix and rootdir 14 | :param rootdir is the root directory 15 | :param suffix is the suffix to be searched 16 | """ 17 | return [ 18 | os.path.join(looproot, filename) 19 | for looproot, _, filenames in os.walk(rootdir) 20 | for filename in filenames 21 | if filename.endswith(suffix) 22 | ] 23 | 24 | 25 | def alpha_blend(input_image, segmentation_mask, alpha=0.5): 26 | """Alpha Blending utility to overlay RGB masks on RBG images 27 | :param input_image is a np.ndarray with 3 channels 28 | :param segmentation_mask is a np.ndarray with 3 channels 29 | :param alpha is a float value 30 | """ 31 | blended = np.zeros(input_image.size, dtype=np.float32) 32 | blended = input_image * alpha + segmentation_mask * (1 - alpha) 33 | return blended 34 | 35 | 36 | def convert_state_dict(state_dict): 37 | """Converts a state dict saved from a dataParallel module to normal 38 | module state_dict inplace 39 | :param state_dict is the loaded DataParallel model_state 40 | """ 41 | if not next(iter(state_dict)).startswith("module."): 42 | return state_dict # abort if dict is not a DataParallel model_state 43 | new_state_dict = OrderedDict() 44 | for k, v in state_dict.items(): 45 | name = k[7:] # remove `module.` 46 | new_state_dict[name] = v 47 | return new_state_dict 48 | 49 | 50 | def get_logger(logdir): 51 | logger = logging.getLogger("ptsemseg") 52 | ts = str(datetime.datetime.now()).split(".")[0].replace(" ", "_") 53 | ts = ts.replace(":", "_").replace("-", "_") 54 | file_path = os.path.join(logdir, "run_{}.log".format(ts)) 55 | hdlr = logging.FileHandler(file_path) 56 | formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s") 57 | hdlr.setFormatter(formatter) 58 | logger.addHandler(hdlr) 59 | logger.setLevel(logging.INFO) 60 | return logger 61 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib==2.0.0 2 | numpy==1.12.1 3 | scipy==0.19.0 4 | torch==0.4.1 5 | torchvision==0.2.0 6 | tqdm==4.11.2 7 | pydensecrf 8 | protobuf 9 | tensorboardX 10 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | import numpy as np 5 | import scipy.misc as misc 6 | 7 | 8 | from ptsemseg.models import get_model 9 | from ptsemseg.loader import get_loader 10 | from ptsemseg.utils import convert_state_dict 11 | 12 | try: 13 | import pydensecrf.densecrf as dcrf 14 | except: 15 | print( 16 | "Failed to import pydensecrf,\ 17 | CRF post-processing will not work" 18 | ) 19 | 20 | 21 | def test(args): 22 | 23 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 24 | 25 | model_file_name = os.path.split(args.model_path)[1] 26 | model_name = model_file_name[: model_file_name.find("_")] 27 | 28 | # Setup image 29 | print("Read Input Image from : {}".format(args.img_path)) 30 | img = misc.imread(args.img_path) 31 | 32 | data_loader = get_loader(args.dataset) 33 | loader = data_loader(root=None, is_transform=True, img_norm=args.img_norm, test_mode=True) 34 | n_classes = loader.n_classes 35 | 36 | resized_img = misc.imresize(img, (loader.img_size[0], loader.img_size[1]), interp="bicubic") 37 | 38 | orig_size = img.shape[:-1] 39 | if model_name in ["pspnet", "icnet", "icnetBN"]: 40 | # uint8 with RGB mode, 
resize width and height which are odd numbers 41 | img = misc.imresize(img, (orig_size[0] // 2 * 2 + 1, orig_size[1] // 2 * 2 + 1)) 42 | else: 43 | img = misc.imresize(img, (loader.img_size[0], loader.img_size[1])) 44 | 45 | img = img[:, :, ::-1] 46 | img = img.astype(np.float64) 47 | img -= loader.mean 48 | if args.img_norm: 49 | img = img.astype(float) / 255.0 50 | 51 | # NHWC -> NCHW 52 | img = img.transpose(2, 0, 1) 53 | img = np.expand_dims(img, 0) 54 | img = torch.from_numpy(img).float() 55 | 56 | # Setup Model 57 | model_dict = {"arch": model_name} 58 | model = get_model(model_dict, n_classes, version=args.dataset) 59 | state = convert_state_dict(torch.load(args.model_path)["model_state"]) 60 | model.load_state_dict(state) 61 | model.eval() 62 | model.to(device) 63 | 64 | images = img.to(device) 65 | outputs = model(images) 66 | 67 | if args.dcrf: 68 | unary = outputs.data.cpu().numpy() 69 | unary = np.squeeze(unary, 0) 70 | unary = -np.log(unary) 71 | unary = unary.transpose(2, 1, 0) 72 | w, h, c = unary.shape 73 | unary = unary.transpose(2, 0, 1).reshape(loader.n_classes, -1) 74 | unary = np.ascontiguousarray(unary) 75 | 76 | resized_img = np.ascontiguousarray(resized_img) 77 | 78 | d = dcrf.DenseCRF2D(w, h, loader.n_classes) 79 | d.setUnaryEnergy(unary) 80 | d.addPairwiseBilateral(sxy=5, srgb=3, rgbim=resized_img, compat=1) 81 | 82 | q = d.inference(50) 83 | mask = np.argmax(q, axis=0).reshape(w, h).transpose(1, 0) 84 | decoded_crf = loader.decode_segmap(np.array(mask, dtype=np.uint8)) 85 | dcrf_path = args.out_path[:-4] + "_drf.png" 86 | misc.imsave(dcrf_path, decoded_crf) 87 | print("Dense CRF Processed Mask Saved at: {}".format(dcrf_path)) 88 | 89 | pred = np.squeeze(outputs.data.max(1)[1].cpu().numpy(), axis=0) 90 | if model_name in ["pspnet", "icnet", "icnetBN"]: 91 | pred = pred.astype(np.float32) 92 | # float32 with F mode, resize back to orig_size 93 | pred = misc.imresize(pred, orig_size, "nearest", mode="F") 94 | 95 | decoded = loader.decode_segmap(pred) 96 | print("Classes found: ", np.unique(pred)) 97 | misc.imsave(args.out_path, decoded) 98 | print("Segmentation Mask Saved at: {}".format(args.out_path)) 99 | 100 | 101 | if __name__ == "__main__": 102 | parser = argparse.ArgumentParser(description="Params") 103 | parser.add_argument( 104 | "--model_path", 105 | nargs="?", 106 | type=str, 107 | default="fcn8s_pascal_1_26.pkl", 108 | help="Path to the saved model", 109 | ) 110 | parser.add_argument( 111 | "--dataset", 112 | nargs="?", 113 | type=str, 114 | default="pascal", 115 | help="Dataset to use ['pascal, camvid, ade20k etc']", 116 | ) 117 | 118 | parser.add_argument( 119 | "--img_norm", 120 | dest="img_norm", 121 | action="store_true", 122 | help="Enable input image scales normalization [0, 1] \ 123 | | True by default", 124 | ) 125 | parser.add_argument( 126 | "--no-img_norm", 127 | dest="img_norm", 128 | action="store_false", 129 | help="Disable input image scales normalization [0, 1] |\ 130 | True by default", 131 | ) 132 | parser.set_defaults(img_norm=True) 133 | 134 | parser.add_argument( 135 | "--dcrf", 136 | dest="dcrf", 137 | action="store_true", 138 | help="Enable DenseCRF based post-processing | \ 139 | False by default", 140 | ) 141 | parser.add_argument( 142 | "--no-dcrf", 143 | dest="dcrf", 144 | action="store_false", 145 | help="Disable DenseCRF based post-processing | \ 146 | False by default", 147 | ) 148 | parser.set_defaults(dcrf=False) 149 | 150 | parser.add_argument( 151 | "--img_path", nargs="?", type=str, default=None, help="Path of the input 
image" 152 | ) 153 | parser.add_argument( 154 | "--out_path", nargs="?", type=str, default=None, help="Path of the output segmap" 155 | ) 156 | args = parser.parse_args() 157 | test(args) 158 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import time 4 | import shutil 5 | import torch 6 | import random 7 | import argparse 8 | import numpy as np 9 | 10 | from torch.utils import data 11 | from tqdm import tqdm 12 | 13 | from ptsemseg.models import get_model 14 | from ptsemseg.loss import get_loss_function 15 | from ptsemseg.loader import get_loader 16 | from ptsemseg.utils import get_logger 17 | from ptsemseg.metrics import runningScore, averageMeter 18 | from ptsemseg.augmentations import get_composed_augmentations 19 | from ptsemseg.schedulers import get_scheduler 20 | from ptsemseg.optimizers import get_optimizer 21 | 22 | from tensorboardX import SummaryWriter 23 | 24 | 25 | def train(cfg, writer, logger): 26 | 27 | # Setup seeds 28 | torch.manual_seed(cfg.get("seed", 1337)) 29 | torch.cuda.manual_seed(cfg.get("seed", 1337)) 30 | np.random.seed(cfg.get("seed", 1337)) 31 | random.seed(cfg.get("seed", 1337)) 32 | 33 | # Setup device 34 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 35 | 36 | # Setup Augmentations 37 | augmentations = cfg["training"].get("augmentations", None) 38 | data_aug = get_composed_augmentations(augmentations) 39 | 40 | # Setup Dataloader 41 | data_loader = get_loader(cfg["data"]["dataset"]) 42 | data_path = cfg["data"]["path"] 43 | 44 | t_loader = data_loader( 45 | data_path, 46 | is_transform=True, 47 | split=cfg["data"]["train_split"], 48 | img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]), 49 | augmentations=data_aug, 50 | ) 51 | 52 | v_loader = data_loader( 53 | data_path, 54 | is_transform=True, 55 | split=cfg["data"]["val_split"], 56 | img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]), 57 | ) 58 | 59 | n_classes = t_loader.n_classes 60 | trainloader = data.DataLoader( 61 | t_loader, 62 | batch_size=cfg["training"]["batch_size"], 63 | num_workers=cfg["training"]["n_workers"], 64 | shuffle=True, 65 | ) 66 | 67 | valloader = data.DataLoader( 68 | v_loader, batch_size=cfg["training"]["batch_size"], num_workers=cfg["training"]["n_workers"] 69 | ) 70 | 71 | # Setup Metrics 72 | running_metrics_val = runningScore(n_classes) 73 | 74 | # Setup Model 75 | model = get_model(cfg["model"], n_classes).to(device) 76 | 77 | model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count())) 78 | 79 | # Setup optimizer, lr_scheduler and loss function 80 | optimizer_cls = get_optimizer(cfg) 81 | optimizer_params = {k: v for k, v in cfg["training"]["optimizer"].items() if k != "name"} 82 | 83 | optimizer = optimizer_cls(model.parameters(), **optimizer_params) 84 | logger.info("Using optimizer {}".format(optimizer)) 85 | 86 | scheduler = get_scheduler(optimizer, cfg["training"]["lr_schedule"]) 87 | 88 | loss_fn = get_loss_function(cfg) 89 | logger.info("Using loss {}".format(loss_fn)) 90 | 91 | start_iter = 0 92 | if cfg["training"]["resume"] is not None: 93 | if os.path.isfile(cfg["training"]["resume"]): 94 | logger.info( 95 | "Loading model and optimizer from checkpoint '{}'".format(cfg["training"]["resume"]) 96 | ) 97 | checkpoint = torch.load(cfg["training"]["resume"]) 98 | model.load_state_dict(checkpoint["model_state"]) 99 | 
optimizer.load_state_dict(checkpoint["optimizer_state"]) 100 | scheduler.load_state_dict(checkpoint["scheduler_state"]) 101 | start_iter = checkpoint["epoch"] 102 | logger.info( 103 | "Loaded checkpoint '{}' (iter {})".format( 104 | cfg["training"]["resume"], checkpoint["epoch"] 105 | ) 106 | ) 107 | else: 108 | logger.info("No checkpoint found at '{}'".format(cfg["training"]["resume"])) 109 | 110 | val_loss_meter = averageMeter() 111 | time_meter = averageMeter() 112 | 113 | best_iou = -100.0 114 | i = start_iter 115 | flag = True 116 | 117 | while i <= cfg["training"]["train_iters"] and flag: 118 | for (images, labels) in trainloader: 119 | i += 1 120 | start_ts = time.time() 121 | scheduler.step() 122 | model.train() 123 | images = images.to(device) 124 | labels = labels.to(device) 125 | 126 | optimizer.zero_grad() 127 | outputs = model(images) 128 | 129 | loss = loss_fn(input=outputs, target=labels) 130 | 131 | loss.backward() 132 | optimizer.step() 133 | 134 | time_meter.update(time.time() - start_ts) 135 | 136 | if (i + 1) % cfg["training"]["print_interval"] == 0: 137 | fmt_str = "Iter [{:d}/{:d}] Loss: {:.4f} Time/Image: {:.4f}" 138 | print_str = fmt_str.format( 139 | i + 1, 140 | cfg["training"]["train_iters"], 141 | loss.item(), 142 | time_meter.avg / cfg["training"]["batch_size"], 143 | ) 144 | 145 | print(print_str) 146 | logger.info(print_str) 147 | writer.add_scalar("loss/train_loss", loss.item(), i + 1) 148 | time_meter.reset() 149 | 150 | if (i + 1) % cfg["training"]["val_interval"] == 0 or (i + 1) == cfg["training"][ 151 | "train_iters" 152 | ]: 153 | model.eval() 154 | with torch.no_grad(): 155 | for i_val, (images_val, labels_val) in tqdm(enumerate(valloader)): 156 | images_val = images_val.to(device) 157 | labels_val = labels_val.to(device) 158 | 159 | outputs = model(images_val) 160 | val_loss = loss_fn(input=outputs, target=labels_val) 161 | 162 | pred = outputs.data.max(1)[1].cpu().numpy() 163 | gt = labels_val.data.cpu().numpy() 164 | 165 | running_metrics_val.update(gt, pred) 166 | val_loss_meter.update(val_loss.item()) 167 | 168 | writer.add_scalar("loss/val_loss", val_loss_meter.avg, i + 1) 169 | logger.info("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg)) 170 | 171 | score, class_iou = running_metrics_val.get_scores() 172 | for k, v in score.items(): 173 | print(k, v) 174 | logger.info("{}: {}".format(k, v)) 175 | writer.add_scalar("val_metrics/{}".format(k), v, i + 1) 176 | 177 | for k, v in class_iou.items(): 178 | logger.info("{}: {}".format(k, v)) 179 | writer.add_scalar("val_metrics/cls_{}".format(k), v, i + 1) 180 | 181 | val_loss_meter.reset() 182 | running_metrics_val.reset() 183 | 184 | if score["Mean IoU : \t"] >= best_iou: 185 | best_iou = score["Mean IoU : \t"] 186 | state = { 187 | "epoch": i + 1, 188 | "model_state": model.state_dict(), 189 | "optimizer_state": optimizer.state_dict(), 190 | "scheduler_state": scheduler.state_dict(), 191 | "best_iou": best_iou, 192 | } 193 | save_path = os.path.join( 194 | writer.file_writer.get_logdir(), 195 | "{}_{}_best_model.pkl".format(cfg["model"]["arch"], cfg["data"]["dataset"]), 196 | ) 197 | torch.save(state, save_path) 198 | 199 | if (i + 1) == cfg["training"]["train_iters"]: 200 | flag = False 201 | break 202 | 203 | 204 | if __name__ == "__main__": 205 | parser = argparse.ArgumentParser(description="config") 206 | parser.add_argument( 207 | "--config", 208 | nargs="?", 209 | type=str, 210 | default="configs/fcn8s_pascal.yml", 211 | help="Configuration file to use", 212 | ) 213 | 214 | args = 
215 | 
216 |     with open(args.config) as fp:
217 |         cfg = yaml.load(fp, Loader=yaml.SafeLoader)  # explicit Loader; required by PyYAML >= 6
218 | 
219 |     run_id = random.randint(1, 100000)
220 |     logdir = os.path.join("runs", os.path.basename(args.config)[:-4], str(run_id))
221 |     writer = SummaryWriter(log_dir=logdir)
222 | 
223 |     print("RUNDIR: {}".format(logdir))
224 |     shutil.copy(args.config, logdir)
225 | 
226 |     logger = get_logger(logdir)
227 |     logger.info("Let the games begin")
228 | 
229 |     train(cfg, writer, logger)
230 | 
--------------------------------------------------------------------------------
/validate.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | import torch
3 | import argparse
4 | import timeit
5 | import numpy as np
6 | 
7 | from torch.utils import data
8 | 
9 | 
10 | from ptsemseg.models import get_model
11 | from ptsemseg.loader import get_loader
12 | from ptsemseg.metrics import runningScore
13 | from ptsemseg.utils import convert_state_dict
14 | 
15 | torch.backends.cudnn.benchmark = True
16 | 
17 | 
18 | def validate(cfg, args):
19 | 
20 |     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21 | 
22 |     # Setup Dataloader
23 |     data_loader = get_loader(cfg["data"]["dataset"])
24 |     data_path = cfg["data"]["path"]
25 | 
26 |     loader = data_loader(
27 |         data_path,
28 |         split=cfg["data"]["val_split"],
29 |         is_transform=True,
30 |         img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
31 |     )
32 | 
33 |     n_classes = loader.n_classes
34 | 
35 |     valloader = data.DataLoader(loader, batch_size=cfg["training"]["batch_size"], num_workers=8)
36 |     running_metrics = runningScore(n_classes)
37 | 
38 |     # Setup Model
39 | 
40 |     model = get_model(cfg["model"], n_classes).to(device)
41 |     state = convert_state_dict(torch.load(args.model_path)["model_state"])
42 |     model.load_state_dict(state)
43 |     model.eval()
44 |     model.to(device)
45 | 
46 |     for i, (images, labels) in enumerate(valloader):
47 |         start_time = timeit.default_timer()
48 | 
49 |         images = images.to(device)
50 | 
51 |         if args.eval_flip:
52 |             outputs = model(images)
53 | 
54 |             # Flip images in numpy (negative-strided slicing is not supported on tensors)
55 |             outputs = outputs.data.cpu().numpy()
56 |             flipped_images = np.copy(images.data.cpu().numpy()[:, :, :, ::-1])
57 |             flipped_images = torch.from_numpy(flipped_images).float().to(device)
58 |             outputs_flipped = model(flipped_images)
59 |             outputs_flipped = outputs_flipped.data.cpu().numpy()
60 |             outputs = (outputs + outputs_flipped[:, :, :, ::-1]) / 2.0
61 | 
62 |             pred = np.argmax(outputs, axis=1)
63 |         else:
64 |             outputs = model(images)
65 |             pred = outputs.data.max(1)[1].cpu().numpy()
66 | 
67 |         gt = labels.numpy()
68 | 
69 |         if args.measure_time:
70 |             elapsed_time = timeit.default_timer() - start_time
71 |             print(
72 |                 "Inference time \
73 |                       (iter {0:5d}): {1:3.5f} fps".format(
74 |                     i + 1, pred.shape[0] / elapsed_time
75 |                 )
76 |             )
77 |         running_metrics.update(gt, pred)
78 | 
79 |     score, class_iou = running_metrics.get_scores()
80 | 
81 |     for k, v in score.items():
82 |         print(k, v)
83 | 
84 |     for i in range(n_classes):
85 |         print(i, class_iou[i])
86 | 
87 | 
88 | if __name__ == "__main__":
89 |     parser = argparse.ArgumentParser(description="Hyperparams")
90 |     parser.add_argument(
91 |         "--config",
92 |         nargs="?",
93 |         type=str,
94 |         default="configs/fcn8s_pascal.yml",
95 |         help="Config file to be used",
96 |     )
97 |     parser.add_argument(
98 |         "--model_path",
99 |         nargs="?",
100 |         type=str,
101 |         default="fcn8s_pascal_1_26.pkl",
102 |         help="Path to the saved model",
103 |     )
104 |     parser.add_argument(
"--eval_flip", 106 | dest="eval_flip", 107 | action="store_true", 108 | help="Enable evaluation with flipped image |\ 109 | True by default", 110 | ) 111 | parser.add_argument( 112 | "--no-eval_flip", 113 | dest="eval_flip", 114 | action="store_false", 115 | help="Disable evaluation with flipped image |\ 116 | True by default", 117 | ) 118 | parser.set_defaults(eval_flip=True) 119 | 120 | parser.add_argument( 121 | "--measure_time", 122 | dest="measure_time", 123 | action="store_true", 124 | help="Enable evaluation with time (fps) measurement |\ 125 | True by default", 126 | ) 127 | parser.add_argument( 128 | "--no-measure_time", 129 | dest="measure_time", 130 | action="store_false", 131 | help="Disable evaluation with time (fps) measurement |\ 132 | True by default", 133 | ) 134 | parser.set_defaults(measure_time=True) 135 | 136 | args = parser.parse_args() 137 | 138 | with open(args.config) as fp: 139 | cfg = yaml.load(fp) 140 | 141 | validate(cfg, args) 142 | --------------------------------------------------------------------------------