├── .gitignore
├── LICENSE
├── README.md
├── configs
│   ├── fcn8s_pascal.yml
│   └── frrnB_cityscapes.yml
├── ptsemseg
│   ├── __init__.py
│   ├── augmentations
│   │   ├── __init__.py
│   │   └── augmentations.py
│   ├── caffe_pb2.py
│   ├── loader
│   │   ├── __init__.py
│   │   ├── ade20k_loader.py
│   │   ├── camvid_loader.py
│   │   ├── cityscapes_loader.py
│   │   ├── mapillary_vistas_loader.py
│   │   ├── mit_sceneparsing_benchmark_loader.py
│   │   ├── nyuv2_loader.py
│   │   ├── pascal_voc_loader.py
│   │   └── sunrgbd_loader.py
│   ├── loss
│   │   ├── __init__.py
│   │   └── loss.py
│   ├── metrics.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── fcn.py
│   │   ├── frrn.py
│   │   ├── icnet.py
│   │   ├── linknet.py
│   │   ├── pspnet.py
│   │   ├── refinenet.py
│   │   ├── segnet.py
│   │   ├── unet.py
│   │   └── utils.py
│   ├── optimizers
│   │   └── __init__.py
│   ├── schedulers
│   │   ├── __init__.py
│   │   └── schedulers.py
│   └── utils.py
├── requirements.txt
├── test.py
├── train.py
└── validate.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # Torch Models
7 | *.pkl
8 | *.pth
9 | current_train.py
10 | video_test*.py
11 | *.swp
12 | data
13 | ckpt
14 |
15 | # C extensions
16 | *.so
17 |
18 | # Distribution / packaging
19 | .Python
20 | local_test.py
21 | .DS_STORE
22 | .idea/
23 | .vscode/
24 | env/
25 | build/
26 | develop-eggs/
27 | dist/
28 | downloads/
29 | eggs/
30 | .eggs/
31 | lib/
32 | lib64/
33 | parts/
34 | sdist/
35 | var/
36 | *.egg-info/
37 | .installed.cfg
38 | *.egg
39 |
40 | # PyInstaller
41 | # Usually these files are written by a python script from a template
42 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
43 | *.manifest
44 | *.spec
45 |
46 | # Installer logs
47 | pip-log.txt
48 | pip-delete-this-directory.txt
49 |
50 | # Unit test / coverage reports
51 | htmlcov/
52 | .tox/
53 | .coverage
54 | .coverage.*
55 | .cache
56 | nosetests.xml
57 | coverage.xml
58 | *,cover
59 | .hypothesis/
60 |
61 | # Translations
62 | *.mo
63 | *.pot
64 |
65 | # Django stuff:
66 | *.log
67 | local_settings.py
68 |
69 | # Flask stuff:
70 | instance/
71 | .webassets-cache
72 |
73 | # Scrapy stuff:
74 | .scrapy
75 |
76 | # Sphinx documentation
77 | docs/_build/
78 |
79 | # PyBuilder
80 | target/
81 |
82 | # IPython Notebook
83 | .ipynb_checkpoints
84 |
85 | # pyenv
86 | .python-version
87 |
88 | # celery beat schedule file
89 | celerybeat-schedule
90 |
91 | # dotenv
92 | .env
93 |
94 | # virtualenv
95 | venv/
96 | ENV/
97 |
98 | # Spyder project settings
99 | .spyderproject
100 |
101 | # Rope project settings
102 | .ropeproject
103 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Meet Pragnesh Shah
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # pytorch-semseg
2 |
3 | [License](https://github.com/meetshah1995/pytorch-semseg/blob/master/LICENSE)
4 | [PyPI](https://pypi.python.org/pypi/pytorch-semseg/0.1.2)
5 | [DOI](https://doi.org/10.5281/zenodo.1185075)
6 |
7 |
8 |
9 | ## Semantic Segmentation Algorithms Implemented in PyTorch
10 |
11 | This repository aims at mirroring popular semantic segmentation architectures in PyTorch.
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 | ### Networks implemented
21 |
22 | * [PSPNet](https://arxiv.org/abs/1612.01105) - With support for loading pretrained models w/o caffe dependency
23 | * [ICNet](https://arxiv.org/pdf/1704.08545.pdf) - With optional batchnorm and pretrained models
24 | * [FRRN](https://arxiv.org/abs/1611.08323) - Model A and B
25 | * [FCN](https://arxiv.org/abs/1411.4038) - All 1 (FCN32s), 2 (FCN16s) and 3 (FCN8s) stream variants
26 | * [U-Net](https://arxiv.org/abs/1505.04597) - With optional deconvolution and batchnorm
27 | * [Link-Net](https://codeac29.github.io/projects/linknet/) - With multiple resnet backends
28 | * [Segnet](https://arxiv.org/abs/1511.00561) - With Unpooling using Maxpool indices
29 |
30 |
31 | #### Upcoming
32 |
33 | * [E-Net](https://arxiv.org/abs/1606.02147)
34 | * [RefineNet](https://arxiv.org/abs/1611.06612)
35 |
36 | ### DataLoaders implemented
37 |
38 | * [CamVid](http://mi.eng.cam.ac.uk/research/projects/VideoRec/CamVid/)
39 | * [Pascal VOC](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/segexamples/index.html)
40 | * [ADE20K](http://groups.csail.mit.edu/vision/datasets/ADE20K/)
41 | * [MIT Scene Parsing Benchmark](http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip)
42 | * [Cityscapes](https://www.cityscapes-dataset.com/)
43 | * [NYUDv2](http://cs.nyu.edu/~silberman/datasets/nyu_depth_v2.html)
44 | * [Sun-RGBD](http://rgbd.cs.princeton.edu/)
45 |
46 |
47 | ### Requirements
48 |
49 | * pytorch >=0.4.0
50 | * torchvision ==0.2.0
51 | * scipy
52 | * tqdm
53 | * tensorboardX
54 |
55 | #### One-line installation
56 |
57 | `pip install -r requirements.txt`
58 |
59 | ### Data
60 |
61 | * Download data for the desired dataset(s) from the list of URLs [here](https://meetshah1995.github.io/semantic-segmentation/deep-learning/pytorch/visdom/2017/06/01/semantic-segmentation-over-the-years.html#sec_datasets).
62 | * Extract the zip / tar and set the dataset path accordingly in your config file (e.g. `configs/fcn8s_pascal.yml`).
63 |
64 |
65 | ### Usage
66 |
67 | **Setup config file**
68 |
69 | ```yaml
70 | # Model Configuration
71 | model:
72 | arch: <name> [options: 'fcn[8,16,32]s, unet, segnet, pspnet, icnet, icnetBN, linknet, frrn[A,B]']
73 | <model_keyarg_1>: <value>
74 |
75 | # Data Configuration
76 | data:
77 | dataset: [options: 'pascal, camvid, ade20k, mit_sceneparsing_benchmark, cityscapes, nyuv2, sunrgbd, vistas']
78 | train_split: <split_to_train_on>
79 | val_split: <split_to_validate_on>
80 | img_rows: 512
81 | img_cols: 1024
82 | path: <path_to_data>
83 | <dataset_keyarg_1>: <value>
84 |
85 | # Training Configuration
86 | training:
87 | n_workers: 64
88 | train_iters: 35000
89 | batch_size: 16
90 | val_interval: 500
91 | print_interval: 25
92 | loss:
93 | name: [options: 'cross_entropy, bootstrapped_cross_entropy, multi_scale_crossentropy']
94 | <loss_keyarg_1>: <value>
95 |
96 | # Optimizer Configuration
97 | optimizer:
98 | name: [options: 'sgd, adam, adamax, asgd, adadelta, adagrad, rmsprop']
99 | lr: 1.0e-3
100 | <optimizer_keyarg_1>: <value>
101 |
102 | # Warmup LR Configuration
103 | warmup_iters: <iters_for_lr_warmup>
104 | mode: <'constant' or 'linear' for warmup>
105 | gamma: <gamma_for_warmup>
106 |
107 | # Augmentations Configuration
108 | augmentations:
109 | gamma: x #[gamma varied in 1 to 1+x]
110 | hue: x #[hue varied in -x to x]
111 | brightness: x #[brightness varied in 1-x to 1+x]
112 | saturation: x #[saturation varied in 1-x to 1+x]
113 | contrast: x #[contrast varied in 1-x to 1+x]
114 | rcrop: [h, w] #[crop of size (h,w)]
115 | translate: [dh, dw] #[reflective translation by (dh, dw)]
116 | rotate: d #[rotate -d to d degrees]
117 | scale: [h,w] #[scale to size (h,w)]
118 | ccrop: [h,w] #[center crop of (h,w)]
119 | hflip: p #[flip horizontally with chance p]
120 | vflip: p #[flip vertically with chance p]
121 |
122 | # LR Schedule Configuration
123 | lr_schedule:
124 | name: [options: 'constant_lr, poly_lr, multi_step, cosine_annealing, exp_lr']
125 | <scheduler_keyarg_1>: <value>
126 |
127 | # Resume from checkpoint
128 | resume: <path_to_checkpoint>
129 | ```
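
Complete, working examples of this format ship with the repository: `configs/fcn8s_pascal.yml` and `configs/frrnB_cityscapes.yml`. As a minimal sketch (the dataset paths below are placeholders to replace with your own), a FCN8s-on-Pascal config looks like:

```yaml
model:
    arch: fcn8s
data:
    dataset: pascal
    train_split: train_aug
    val_split: val
    img_rows: 'same'
    img_cols: 'same'
    path: /path/to/VOCdevkit/VOC2012/        # placeholder
    sbd_path: /path/to/benchmark_RELEASE/    # placeholder
training:
    train_iters: 300000
    batch_size: 1
    val_interval: 1000
    n_workers: 16
    print_interval: 50
    optimizer:
        name: 'sgd'
        lr: 1.0e-10
        weight_decay: 0.0005
        momentum: 0.99
    loss:
        name: 'cross_entropy'
        size_average: False
    lr_schedule:
    resume: fcn8s_pascal_best_model.pkl
```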
130 |
131 | **To train the model :**
132 |
133 | ```
134 | python train.py [-h] [--config [CONFIG]]
135 |
136 | --config Configuration file to use
137 | ```
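
For example, to train FCN8s on Pascal VOC with the provided config:

```
python train.py --config configs/fcn8s_pascal.yml
```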
138 |
139 | **To validate the model :**
140 |
141 | ```
142 | usage: validate.py [-h] [--config [CONFIG]] [--model_path [MODEL_PATH]]
143 | [--eval_flip] [--measure_time]
144 |
145 | --config Config file to be used
146 | --model_path Path to the saved model
147 | --eval_flip Enable evaluation with flipped image | True by default
148 | --measure_time Enable evaluation with time (fps) measurement | True
149 | by default
150 | ```
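
For example, validating the FCN8s model against the checkpoint named in the config's `resume` field (substitute the path to whichever checkpoint you trained):

```
python validate.py --config configs/fcn8s_pascal.yml --model_path fcn8s_pascal_best_model.pkl
```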
151 |
152 | **To test the model w.r.t. a dataset on custom image(s):**
153 |
154 | ```
155 | python test.py [-h] [--model_path [MODEL_PATH]] [--dataset [DATASET]]
156 | [--dcrf [DCRF]] [--img_path [IMG_PATH]] [--out_path [OUT_PATH]]
157 |
158 | --model_path Path to the saved model
159 | --dataset Dataset to use ['pascal, camvid, ade20k etc']
160 | --dcrf Enable DenseCRF based post-processing
161 | --img_path Path of the input image
162 | --out_path Path of the output segmap
163 | ```
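
For example, assuming a Pascal-trained checkpoint and an input image `scene.jpg` (both paths are placeholders):

```
python test.py --model_path fcn8s_pascal_best_model.pkl --dataset pascal --img_path scene.jpg --out_path scene_segmap.png
```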
164 |
165 |
166 | **If you find this code useful in your research, please consider citing:**
167 |
168 | ```
169 | @article{mshahsemseg,
170 | Author = {Meet P Shah},
171 | Title = {Semantic Segmentation Architectures Implemented in PyTorch.},
172 | Journal = {https://github.com/meetshah1995/pytorch-semseg},
173 | Year = {2017}
174 | }
175 | ```
176 |
177 |
--------------------------------------------------------------------------------
/configs/fcn8s_pascal.yml:
--------------------------------------------------------------------------------
1 | model:
2 | arch: fcn8s
3 | data:
4 | dataset: pascal
5 | train_split: train_aug
6 | val_split: val
7 | img_rows: 'same'
8 | img_cols: 'same'
9 | path: /private/home/meetshah/datasets/VOC/060817/VOCdevkit/VOC2012/
10 | sbd_path: /private/home/meetshah/datasets/VOC/benchmark_RELEASE/
11 | training:
12 | train_iters: 300000
13 | batch_size: 1
14 | val_interval: 1000
15 | n_workers: 16
16 | print_interval: 50
17 | optimizer:
18 | name: 'sgd'
19 | lr: 1.0e-10
20 | weight_decay: 0.0005
21 | momentum: 0.99
22 | loss:
23 | name: 'cross_entropy'
24 | size_average: False
25 | lr_schedule:
26 | resume: fcn8s_pascal_best_model.pkl
27 |
--------------------------------------------------------------------------------
/configs/frrnB_cityscapes.yml:
--------------------------------------------------------------------------------
1 | model:
2 | arch: frrnB
3 | data:
4 | dataset: cityscapes
5 | train_split: train
6 | val_split: val
7 | img_rows: 512
8 | img_cols: 1024
9 | path: /private/home/meetshah/misc_code/ps/data/VOCdevkit/VOC2012/
10 | training:
11 | train_iters: 85000
12 | batch_size: 2
13 | val_interval: 500
14 | print_interval: 25
15 | optimizer:
16 | lr: 1.0e-4
17 | l_rate: 1.0e-4
18 | l_schedule:
19 | momentum: 0.99
20 | weight_decay: 0.0005
21 | resume: frrnB_cityscapes_best_model.pkl
22 | visdom: False
23 |
--------------------------------------------------------------------------------
/ptsemseg/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/meetps/pytorch-semseg/801fb200547caa5b0d91b8dde56b837da029f746/ptsemseg/__init__.py
--------------------------------------------------------------------------------
/ptsemseg/augmentations/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from ptsemseg.augmentations.augmentations import (
3 | AdjustContrast,
4 | AdjustGamma,
5 | AdjustBrightness,
6 | AdjustSaturation,
7 | AdjustHue,
8 | RandomCrop,
9 | RandomHorizontallyFlip,
10 | RandomVerticallyFlip,
11 | Scale,
12 | RandomSized,
13 | RandomSizedCrop,
14 | RandomRotate,
15 | RandomTranslate,
16 | CenterCrop,
17 | Compose,
18 | )
19 |
20 | logger = logging.getLogger("ptsemseg")
21 |
22 | key2aug = {
23 | "gamma": AdjustGamma,
24 | "hue": AdjustHue,
25 | "brightness": AdjustBrightness,
26 | "saturation": AdjustSaturation,
27 | "contrast": AdjustContrast,
28 | "rcrop": RandomCrop,
29 | "hflip": RandomHorizontallyFlip,
30 | "vflip": RandomVerticallyFlip,
31 | "scale": Scale,
32 | "rsize": RandomSized,
33 | "rsizecrop": RandomSizedCrop,
34 | "rotate": RandomRotate,
35 | "translate": RandomTranslate,
36 | "ccrop": CenterCrop,
37 | }
38 |
39 |
40 | def get_composed_augmentations(aug_dict):
41 | if aug_dict is None:
42 | logger.info("Using No Augmentations")
43 | return None
44 |
45 | augmentations = []
46 | for aug_key, aug_param in aug_dict.items():
47 | augmentations.append(key2aug[aug_key](aug_param))
48 | logger.info("Using {} aug with params {}".format(aug_key, aug_param))
49 | return Compose(augmentations)
50 |
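# Usage sketch (illustrative values): each key in `key2aug` maps a config entry
# to its transform, so an `augmentations` dict from the config such as
# {"rcrop": [512, 512], "hflip": 0.5} composes a 512x512 RandomCrop with a
# horizontal flip applied with probability 0.5:
#
#   augs = get_composed_augmentations({"rcrop": [512, 512], "hflip": 0.5})
#   img, lbl = augs(img, lbl)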
--------------------------------------------------------------------------------
/ptsemseg/augmentations/augmentations.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numbers
3 | import random
4 | import numpy as np
5 | import torchvision.transforms.functional as tf
6 |
7 | from PIL import Image, ImageOps
8 |
9 |
10 | class Compose(object):
11 | def __init__(self, augmentations):
12 | self.augmentations = augmentations
13 | self.PIL2Numpy = False
14 |
15 | def __call__(self, img, mask):
16 | if isinstance(img, np.ndarray):
17 | img = Image.fromarray(img, mode="RGB")
18 | mask = Image.fromarray(mask, mode="L")
19 | self.PIL2Numpy = True
20 |
21 | assert img.size == mask.size
22 | for a in self.augmentations:
23 | img, mask = a(img, mask)
24 |
25 | if self.PIL2Numpy:
26 | img, mask = np.array(img), np.array(mask, dtype=np.uint8)
27 |
28 | return img, mask
29 |
30 |
31 | class RandomCrop(object):
32 | def __init__(self, size, padding=0):
33 | if isinstance(size, numbers.Number):
34 | self.size = (int(size), int(size))
35 | else:
36 | self.size = size
37 | self.padding = padding
38 |
39 | def __call__(self, img, mask):
40 | if self.padding > 0:
41 | img = ImageOps.expand(img, border=self.padding, fill=0)
42 | mask = ImageOps.expand(mask, border=self.padding, fill=0)
43 |
44 | assert img.size == mask.size
45 | w, h = img.size
46 | th, tw = self.size
47 | if w == tw and h == th:
48 | return img, mask
49 | if w < tw or h < th:
50 | return (img.resize((tw, th), Image.BILINEAR), mask.resize((tw, th), Image.NEAREST))
51 |
52 | x1 = random.randint(0, w - tw)
53 | y1 = random.randint(0, h - th)
54 | return (img.crop((x1, y1, x1 + tw, y1 + th)), mask.crop((x1, y1, x1 + tw, y1 + th)))
55 |
56 |
57 | class AdjustGamma(object):
58 | def __init__(self, gamma):
59 | self.gamma = gamma
60 |
61 | def __call__(self, img, mask):
62 | assert img.size == mask.size
63 | return tf.adjust_gamma(img, random.uniform(1, 1 + self.gamma)), mask
64 |
65 |
66 | class AdjustSaturation(object):
67 | def __init__(self, saturation):
68 | self.saturation = saturation
69 |
70 | def __call__(self, img, mask):
71 | assert img.size == mask.size
72 | return (
73 | tf.adjust_saturation(img, random.uniform(1 - self.saturation, 1 + self.saturation)),
74 | mask,
75 | )
76 |
77 |
78 | class AdjustHue(object):
79 | def __init__(self, hue):
80 | self.hue = hue
81 |
82 | def __call__(self, img, mask):
83 | assert img.size == mask.size
84 | return tf.adjust_hue(img, random.uniform(-self.hue, self.hue)), mask
85 |
86 |
87 | class AdjustBrightness(object):
88 | def __init__(self, bf):
89 | self.bf = bf
90 |
91 | def __call__(self, img, mask):
92 | assert img.size == mask.size
93 | return tf.adjust_brightness(img, random.uniform(1 - self.bf, 1 + self.bf)), mask
94 |
95 |
96 | class AdjustContrast(object):
97 | def __init__(self, cf):
98 | self.cf = cf
99 |
100 | def __call__(self, img, mask):
101 | assert img.size == mask.size
102 | return tf.adjust_contrast(img, random.uniform(1 - self.cf, 1 + self.cf)), mask
103 |
104 |
105 | class CenterCrop(object):
106 | def __init__(self, size):
107 | if isinstance(size, numbers.Number):
108 | self.size = (int(size), int(size))
109 | else:
110 | self.size = size
111 |
112 | def __call__(self, img, mask):
113 | assert img.size == mask.size
114 | w, h = img.size
115 | th, tw = self.size
116 | x1 = int(round((w - tw) / 2.0))
117 | y1 = int(round((h - th) / 2.0))
118 | return (img.crop((x1, y1, x1 + tw, y1 + th)), mask.crop((x1, y1, x1 + tw, y1 + th)))
119 |
120 |
121 | class RandomHorizontallyFlip(object):
122 | def __init__(self, p):
123 | self.p = p
124 |
125 | def __call__(self, img, mask):
126 | if random.random() < self.p:
127 | return (img.transpose(Image.FLIP_LEFT_RIGHT), mask.transpose(Image.FLIP_LEFT_RIGHT))
128 | return img, mask
129 |
130 |
131 | class RandomVerticallyFlip(object):
132 | def __init__(self, p):
133 | self.p = p
134 |
135 | def __call__(self, img, mask):
136 | if random.random() < self.p:
137 | return (img.transpose(Image.FLIP_TOP_BOTTOM), mask.transpose(Image.FLIP_TOP_BOTTOM))
138 | return img, mask
139 |
140 |
141 | class FreeScale(object):
142 | def __init__(self, size):
143 | self.size = tuple(reversed(size)) # size: (h, w)
144 |
145 | def __call__(self, img, mask):
146 | assert img.size == mask.size
147 | return (img.resize(self.size, Image.BILINEAR), mask.resize(self.size, Image.NEAREST))
148 |
149 |
150 | class RandomTranslate(object):
151 | def __init__(self, offset):
152 | # tuple (delta_x, delta_y)
153 | self.offset = offset
154 |
155 | def __call__(self, img, mask):
156 | assert img.size == mask.size
157 | x_offset = int(2 * (random.random() - 0.5) * self.offset[0])
158 | y_offset = int(2 * (random.random() - 0.5) * self.offset[1])
159 |
160 | x_crop_offset = x_offset
161 | y_crop_offset = y_offset
162 | if x_offset < 0:
163 | x_crop_offset = 0
164 | if y_offset < 0:
165 | y_crop_offset = 0
166 |
167 | cropped_img = tf.crop(
168 | img,
169 | y_crop_offset,
170 | x_crop_offset,
171 | img.size[1] - abs(y_offset),
172 | img.size[0] - abs(x_offset),
173 | )
174 |
175 | if x_offset >= 0 and y_offset >= 0:
176 | padding_tuple = (0, 0, x_offset, y_offset)
177 |
178 | elif x_offset >= 0 and y_offset < 0:
179 | padding_tuple = (0, abs(y_offset), x_offset, 0)
180 |
181 | elif x_offset < 0 and y_offset >= 0:
182 | padding_tuple = (abs(x_offset), 0, 0, y_offset)
183 |
184 | elif x_offset < 0 and y_offset < 0:
185 | padding_tuple = (abs(x_offset), abs(y_offset), 0, 0)
186 |
187 | return (
188 | tf.pad(cropped_img, padding_tuple, padding_mode="reflect"),
189 | tf.affine(
190 | mask,
191 | translate=(-x_offset, -y_offset),
192 | scale=1.0,
193 | angle=0.0,
194 | shear=0.0,
195 | fillcolor=250,
196 | ),
197 | )
198 |
199 |
200 | class RandomRotate(object):
201 | def __init__(self, degree):
202 | self.degree = degree
203 |
204 | def __call__(self, img, mask):
205 | rotate_degree = random.random() * 2 * self.degree - self.degree
206 | return (
207 | tf.affine(
208 | img,
209 | translate=(0, 0),
210 | scale=1.0,
211 | angle=rotate_degree,
212 | resample=Image.BILINEAR,
213 | fillcolor=(0, 0, 0),
214 | shear=0.0,
215 | ),
216 | tf.affine(
217 | mask,
218 | translate=(0, 0),
219 | scale=1.0,
220 | angle=rotate_degree,
221 | resample=Image.NEAREST,
222 | fillcolor=250,
223 | shear=0.0,
224 | ),
225 | )
226 |
227 |
228 | class Scale(object):
229 | def __init__(self, size):
230 | self.size = size
231 |
232 | def __call__(self, img, mask):
233 | assert img.size == mask.size
234 | w, h = img.size
235 | if (w >= h and w == self.size) or (h >= w and h == self.size):
236 | return img, mask
237 | if w > h:
238 | ow = self.size
239 | oh = int(self.size * h / w)
240 | return (img.resize((ow, oh), Image.BILINEAR), mask.resize((ow, oh), Image.NEAREST))
241 | else:
242 | oh = self.size
243 | ow = int(self.size * w / h)
244 | return (img.resize((ow, oh), Image.BILINEAR), mask.resize((ow, oh), Image.NEAREST))
245 |
246 |
247 | class RandomSizedCrop(object):
248 | def __init__(self, size):
249 | self.size = size
250 |
251 | def __call__(self, img, mask):
252 | assert img.size == mask.size
253 | for attempt in range(10):
254 | area = img.size[0] * img.size[1]
255 | target_area = random.uniform(0.45, 1.0) * area
256 | aspect_ratio = random.uniform(0.5, 2)
257 |
258 | w = int(round(math.sqrt(target_area * aspect_ratio)))
259 | h = int(round(math.sqrt(target_area / aspect_ratio)))
260 |
261 | if random.random() < 0.5:
262 | w, h = h, w
263 |
264 | if w <= img.size[0] and h <= img.size[1]:
265 | x1 = random.randint(0, img.size[0] - w)
266 | y1 = random.randint(0, img.size[1] - h)
267 |
268 | img = img.crop((x1, y1, x1 + w, y1 + h))
269 | mask = mask.crop((x1, y1, x1 + w, y1 + h))
270 | assert img.size == (w, h)
271 |
272 | return (
273 | img.resize((self.size, self.size), Image.BILINEAR),
274 | mask.resize((self.size, self.size), Image.NEAREST),
275 | )
276 |
277 | # Fallback
278 | scale = Scale(self.size)
279 | crop = CenterCrop(self.size)
280 | return crop(*scale(img, mask))
281 |
282 |
283 | class RandomSized(object):
284 | def __init__(self, size):
285 | self.size = size
286 | self.scale = Scale(self.size)
287 | self.crop = RandomCrop(self.size)
288 |
289 | def __call__(self, img, mask):
290 | assert img.size == mask.size
291 |
292 | w = int(random.uniform(0.5, 2) * img.size[0])
293 | h = int(random.uniform(0.5, 2) * img.size[1])
294 |
295 | img, mask = (img.resize((w, h), Image.BILINEAR), mask.resize((w, h), Image.NEAREST))
296 |
297 | return self.crop(*self.scale(img, mask))
298 |
--------------------------------------------------------------------------------
/ptsemseg/loader/__init__.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | from ptsemseg.loader.pascal_voc_loader import pascalVOCLoader
4 | from ptsemseg.loader.camvid_loader import camvidLoader
5 | from ptsemseg.loader.ade20k_loader import ADE20KLoader
6 | from ptsemseg.loader.mit_sceneparsing_benchmark_loader import MITSceneParsingBenchmarkLoader
7 | from ptsemseg.loader.cityscapes_loader import cityscapesLoader
8 | from ptsemseg.loader.nyuv2_loader import NYUv2Loader
9 | from ptsemseg.loader.sunrgbd_loader import SUNRGBDLoader
10 | from ptsemseg.loader.mapillary_vistas_loader import mapillaryVistasLoader
11 |
12 |
13 | def get_loader(name):
14 | """get_loader
15 |
16 | :param name:
17 | """
18 | return {
19 | "pascal": pascalVOCLoader,
20 | "camvid": camvidLoader,
21 | "ade20k": ADE20KLoader,
22 | "mit_sceneparsing_benchmark": MITSceneParsingBenchmarkLoader,
23 | "cityscapes": cityscapesLoader,
24 | "nyuv2": NYUv2Loader,
25 | "sunrgbd": SUNRGBDLoader,
26 | "vistas": mapillaryVistasLoader,
27 | }[name]
28 |
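# Usage sketch (illustrative path): `get_loader` returns the loader *class* for
# a dataset key; instantiate it yourself, e.g.:
#
#   loader_cls = get_loader("cityscapes")
#   dataset = loader_cls(root="/path/to/cityscapes", is_transform=True)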
--------------------------------------------------------------------------------
/ptsemseg/loader/ade20k_loader.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import torch
3 | import torchvision
4 | import numpy as np
5 | import scipy.misc as m
6 | import matplotlib.pyplot as plt
7 |
8 | from torch.utils import data
9 |
10 | from ptsemseg.utils import recursive_glob
11 |
12 |
13 | class ADE20KLoader(data.Dataset):
14 | def __init__(
15 | self,
16 | root,
17 | split="training",
18 | is_transform=False,
19 | img_size=512,
20 | augmentations=None,
21 | img_norm=True,
22 | test_mode=False,
23 | ):
24 | self.root = root
25 | self.split = split
26 | self.is_transform = is_transform
27 | self.augmentations = augmentations
28 | self.img_norm = img_norm
29 | self.test_mode = test_mode
30 | self.n_classes = 150
31 | self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size)
32 | self.mean = np.array([104.00699, 116.66877, 122.67892])
33 | self.files = collections.defaultdict(list)
34 |
35 | if not self.test_mode:
36 | for split in ["training", "validation"]:
37 |                 file_list = recursive_glob(
38 |                     rootdir=self.root + "images/" + split + "/", suffix=".jpg"  # glob each split's own images
39 |                 )
40 | self.files[split] = file_list
41 |
42 | def __len__(self):
43 | return len(self.files[self.split])
44 |
45 | def __getitem__(self, index):
46 | img_path = self.files[self.split][index].rstrip()
47 | lbl_path = img_path[:-4] + "_seg.png"
48 |
49 | img = m.imread(img_path)
50 | img = np.array(img, dtype=np.uint8)
51 |
52 | lbl = m.imread(lbl_path)
53 | lbl = np.array(lbl, dtype=np.int32)
54 |
55 | if self.augmentations is not None:
56 | img, lbl = self.augmentations(img, lbl)
57 |
58 | if self.is_transform:
59 | img, lbl = self.transform(img, lbl)
60 |
61 | return img, lbl
62 |
63 | def transform(self, img, lbl):
64 | img = m.imresize(img, (self.img_size[0], self.img_size[1])) # uint8 with RGB mode
65 | img = img[:, :, ::-1] # RGB -> BGR
66 | img = img.astype(np.float64)
67 | img -= self.mean
68 | if self.img_norm:
69 | # Resize scales images from 0 to 255, thus we need
70 | # to divide by 255.0
71 | img = img.astype(float) / 255.0
72 | # NHWC -> NCHW
73 | img = img.transpose(2, 0, 1)
74 |
75 | lbl = self.encode_segmap(lbl)
76 | classes = np.unique(lbl)
77 | lbl = lbl.astype(float)
78 | lbl = m.imresize(lbl, (self.img_size[0], self.img_size[1]), "nearest", mode="F")
79 | lbl = lbl.astype(int)
80 | assert np.all(classes == np.unique(lbl))
81 |
82 | img = torch.from_numpy(img).float()
83 | lbl = torch.from_numpy(lbl).long()
84 | return img, lbl
85 |
86 | def encode_segmap(self, mask):
87 | # Refer : http://groups.csail.mit.edu/vision/datasets/ADE20K/code/loadAde20K.m
88 | mask = mask.astype(int)
89 | label_mask = np.zeros((mask.shape[0], mask.shape[1]))
90 | label_mask = (mask[:, :, 0] / 10.0) * 256 + mask[:, :, 1]
91 | return np.array(label_mask, dtype=np.uint8)
92 |
93 | def decode_segmap(self, temp, plot=False):
94 | # TODO:(@meetshah1995)
95 | # Verify that the color mapping is 1-to-1
96 | r = temp.copy()
97 | g = temp.copy()
98 | b = temp.copy()
99 | for l in range(0, self.n_classes):
100 | r[temp == l] = 10 * (l % 10)
101 | g[temp == l] = l
102 | b[temp == l] = 0
103 |
104 | rgb = np.zeros((temp.shape[0], temp.shape[1], 3))
105 | rgb[:, :, 0] = r / 255.0
106 | rgb[:, :, 1] = g / 255.0
107 | rgb[:, :, 2] = b / 255.0
108 | if plot:
109 | plt.imshow(rgb)
110 | plt.show()
111 | else:
112 | return rgb
113 |
114 |
115 | if __name__ == "__main__":
116 | local_path = "/Users/meet/data/ADE20K_2016_07_26/"
117 | dst = ADE20KLoader(local_path, is_transform=True)
118 | trainloader = data.DataLoader(dst, batch_size=4)
119 | for i, data_samples in enumerate(trainloader):
120 | imgs, labels = data_samples
121 | if i == 0:
122 | img = torchvision.utils.make_grid(imgs).numpy()
123 | img = np.transpose(img, (1, 2, 0))
124 | img = img[:, :, ::-1]
125 | plt.imshow(img)
126 | plt.show()
127 | for j in range(4):
128 | plt.imshow(dst.decode_segmap(labels.numpy()[j]))
129 | plt.show()
130 |
--------------------------------------------------------------------------------
/ptsemseg/loader/camvid_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import collections
3 | import torch
4 | import numpy as np
5 | import scipy.misc as m
6 | import matplotlib.pyplot as plt
7 |
8 | from torch.utils import data
9 | from ptsemseg.augmentations import Compose, RandomHorizontallyFlip, RandomRotate
10 |
11 |
12 | class camvidLoader(data.Dataset):
13 | def __init__(
14 | self,
15 | root,
16 | split="train",
17 | is_transform=False,
18 | img_size=None,
19 | augmentations=None,
20 | img_norm=True,
21 | test_mode=False,
22 | ):
23 | self.root = root
24 | self.split = split
25 | self.img_size = [360, 480]
26 | self.is_transform = is_transform
27 | self.augmentations = augmentations
28 | self.img_norm = img_norm
29 | self.test_mode = test_mode
30 | self.mean = np.array([104.00699, 116.66877, 122.67892])
31 | self.n_classes = 12
32 | self.files = collections.defaultdict(list)
33 |
34 | if not self.test_mode:
35 | for split in ["train", "test", "val"]:
36 | file_list = os.listdir(root + "/" + split)
37 | self.files[split] = file_list
38 |
39 | def __len__(self):
40 | return len(self.files[self.split])
41 |
42 | def __getitem__(self, index):
43 | img_name = self.files[self.split][index]
44 | img_path = self.root + "/" + self.split + "/" + img_name
45 | lbl_path = self.root + "/" + self.split + "annot/" + img_name
46 |
47 | img = m.imread(img_path)
48 | img = np.array(img, dtype=np.uint8)
49 |
50 | lbl = m.imread(lbl_path)
51 | lbl = np.array(lbl, dtype=np.int8)
52 |
53 | if self.augmentations is not None:
54 | img, lbl = self.augmentations(img, lbl)
55 |
56 | if self.is_transform:
57 | img, lbl = self.transform(img, lbl)
58 |
59 | return img, lbl
60 |
61 | def transform(self, img, lbl):
62 | img = m.imresize(img, (self.img_size[0], self.img_size[1])) # uint8 with RGB mode
63 | img = img[:, :, ::-1] # RGB -> BGR
64 | img = img.astype(np.float64)
65 | img -= self.mean
66 | if self.img_norm:
67 | # Resize scales images from 0 to 255, thus we need
68 | # to divide by 255.0
69 | img = img.astype(float) / 255.0
70 | # NHWC -> NCHW
71 | img = img.transpose(2, 0, 1)
72 |
73 | img = torch.from_numpy(img).float()
74 | lbl = torch.from_numpy(lbl).long()
75 | return img, lbl
76 |
77 | def decode_segmap(self, temp, plot=False):
78 | Sky = [128, 128, 128]
79 | Building = [128, 0, 0]
80 | Pole = [192, 192, 128]
81 | Road = [128, 64, 128]
82 | Pavement = [60, 40, 222]
83 | Tree = [128, 128, 0]
84 | SignSymbol = [192, 128, 128]
85 | Fence = [64, 64, 128]
86 | Car = [64, 0, 128]
87 | Pedestrian = [64, 64, 0]
88 | Bicyclist = [0, 128, 192]
89 | Unlabelled = [0, 0, 0]
90 |
91 | label_colours = np.array(
92 | [
93 | Sky,
94 | Building,
95 | Pole,
96 | Road,
97 | Pavement,
98 | Tree,
99 | SignSymbol,
100 | Fence,
101 | Car,
102 | Pedestrian,
103 | Bicyclist,
104 | Unlabelled,
105 | ]
106 | )
107 | r = temp.copy()
108 | g = temp.copy()
109 | b = temp.copy()
110 | for l in range(0, self.n_classes):
111 | r[temp == l] = label_colours[l, 0]
112 | g[temp == l] = label_colours[l, 1]
113 | b[temp == l] = label_colours[l, 2]
114 |
115 | rgb = np.zeros((temp.shape[0], temp.shape[1], 3))
116 | rgb[:, :, 0] = r / 255.0
117 | rgb[:, :, 1] = g / 255.0
118 | rgb[:, :, 2] = b / 255.0
119 | return rgb
120 |
121 |
122 | if __name__ == "__main__":
123 | local_path = "/home/meetshah1995/datasets/segnet/CamVid"
124 | augmentations = Compose([RandomRotate(10), RandomHorizontallyFlip()])
125 |
126 | dst = camvidLoader(local_path, is_transform=True, augmentations=augmentations)
127 | bs = 4
128 | trainloader = data.DataLoader(dst, batch_size=bs)
129 | for i, data_samples in enumerate(trainloader):
130 | imgs, labels = data_samples
131 | imgs = imgs.numpy()[:, ::-1, :, :]
132 | imgs = np.transpose(imgs, [0, 2, 3, 1])
133 | f, axarr = plt.subplots(bs, 2)
134 | for j in range(bs):
135 | axarr[j][0].imshow(imgs[j])
136 | axarr[j][1].imshow(dst.decode_segmap(labels.numpy()[j]))
137 | plt.show()
138 | a = input()
139 | if a == "ex":
140 | break
141 | else:
142 | plt.close()
143 |
--------------------------------------------------------------------------------
/ptsemseg/loader/cityscapes_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | import scipy.misc as m
5 |
6 | from torch.utils import data
7 |
8 | from ptsemseg.utils import recursive_glob
9 | from ptsemseg.augmentations import Compose, RandomHorizontallyFlip, RandomRotate, Scale
10 |
11 |
12 | class cityscapesLoader(data.Dataset):
13 | """cityscapesLoader
14 |
15 | https://www.cityscapes-dataset.com
16 |
17 | Data is derived from CityScapes, and can be downloaded from here:
18 | https://www.cityscapes-dataset.com/downloads/
19 |
20 | Many Thanks to @fvisin for the loader repo:
21 | https://github.com/fvisin/dataset_loaders/blob/master/dataset_loaders/images/cityscapes.py
22 | """
23 |
24 | colors = [ # [ 0, 0, 0],
25 | [128, 64, 128],
26 | [244, 35, 232],
27 | [70, 70, 70],
28 | [102, 102, 156],
29 | [190, 153, 153],
30 | [153, 153, 153],
31 | [250, 170, 30],
32 | [220, 220, 0],
33 | [107, 142, 35],
34 | [152, 251, 152],
35 | [0, 130, 180],
36 | [220, 20, 60],
37 | [255, 0, 0],
38 | [0, 0, 142],
39 | [0, 0, 70],
40 | [0, 60, 100],
41 | [0, 80, 100],
42 | [0, 0, 230],
43 | [119, 11, 32],
44 | ]
45 |
46 | label_colours = dict(zip(range(19), colors))
47 |
48 | mean_rgb = {
49 | "pascal": [103.939, 116.779, 123.68],
50 | "cityscapes": [0.0, 0.0, 0.0],
51 | } # pascal mean for PSPNet and ICNet pre-trained model
52 |
53 | def __init__(
54 | self,
55 | root,
56 | split="train",
57 | is_transform=False,
58 | img_size=(512, 1024),
59 | augmentations=None,
60 | img_norm=True,
61 | version="cityscapes",
62 | test_mode=False,
63 | ):
64 | """__init__
65 |
66 | :param root:
67 | :param split:
68 | :param is_transform:
69 | :param img_size:
70 | :param augmentations
71 | """
72 | self.root = root
73 | self.split = split
74 | self.is_transform = is_transform
75 | self.augmentations = augmentations
76 | self.img_norm = img_norm
77 | self.n_classes = 19
78 | self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size)
79 | self.mean = np.array(self.mean_rgb[version])
80 | self.files = {}
81 |
82 | self.images_base = os.path.join(self.root, "leftImg8bit", self.split)
83 | self.annotations_base = os.path.join(self.root, "gtFine", self.split)
84 |
85 | self.files[split] = recursive_glob(rootdir=self.images_base, suffix=".png")
86 |
87 | self.void_classes = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1]
88 | self.valid_classes = [
89 | 7,
90 | 8,
91 | 11,
92 | 12,
93 | 13,
94 | 17,
95 | 19,
96 | 20,
97 | 21,
98 | 22,
99 | 23,
100 | 24,
101 | 25,
102 | 26,
103 | 27,
104 | 28,
105 | 31,
106 | 32,
107 | 33,
108 | ]
109 | self.class_names = [
110 | "unlabelled",
111 | "road",
112 | "sidewalk",
113 | "building",
114 | "wall",
115 | "fence",
116 | "pole",
117 | "traffic_light",
118 | "traffic_sign",
119 | "vegetation",
120 | "terrain",
121 | "sky",
122 | "person",
123 | "rider",
124 | "car",
125 | "truck",
126 | "bus",
127 | "train",
128 | "motorcycle",
129 | "bicycle",
130 | ]
131 |
132 | self.ignore_index = 250
133 | self.class_map = dict(zip(self.valid_classes, range(19)))
134 |
135 | if not self.files[split]:
136 | raise Exception("No files for split=[%s] found in %s" % (split, self.images_base))
137 |
138 | print("Found %d %s images" % (len(self.files[split]), split))
139 |
140 | def __len__(self):
141 | """__len__"""
142 | return len(self.files[self.split])
143 |
144 | def __getitem__(self, index):
145 | """__getitem__
146 |
147 | :param index:
148 | """
149 | img_path = self.files[self.split][index].rstrip()
150 | lbl_path = os.path.join(
151 | self.annotations_base,
152 | img_path.split(os.sep)[-2],
153 | os.path.basename(img_path)[:-15] + "gtFine_labelIds.png",
154 | )
155 |
156 | img = m.imread(img_path)
157 | img = np.array(img, dtype=np.uint8)
158 |
159 | lbl = m.imread(lbl_path)
160 | lbl = self.encode_segmap(np.array(lbl, dtype=np.uint8))
161 |
162 | if self.augmentations is not None:
163 | img, lbl = self.augmentations(img, lbl)
164 |
165 | if self.is_transform:
166 | img, lbl = self.transform(img, lbl)
167 |
168 | return img, lbl
169 |
170 | def transform(self, img, lbl):
171 | """transform
172 |
173 | :param img:
174 | :param lbl:
175 | """
176 | img = m.imresize(img, (self.img_size[0], self.img_size[1])) # uint8 with RGB mode
177 | img = img[:, :, ::-1] # RGB -> BGR
178 | img = img.astype(np.float64)
179 | img -= self.mean
180 | if self.img_norm:
181 | # Resize scales images from 0 to 255, thus we need
182 | # to divide by 255.0
183 | img = img.astype(float) / 255.0
184 | # NHWC -> NCHW
185 | img = img.transpose(2, 0, 1)
186 |
187 | classes = np.unique(lbl)
188 | lbl = lbl.astype(float)
189 | lbl = m.imresize(lbl, (self.img_size[0], self.img_size[1]), "nearest", mode="F")
190 | lbl = lbl.astype(int)
191 |
192 | if not np.all(classes == np.unique(lbl)):
193 | print("WARN: resizing labels yielded fewer classes")
194 |
195 | if not np.all(np.unique(lbl[lbl != self.ignore_index]) < self.n_classes):
196 | print("after det", classes, np.unique(lbl))
197 | raise ValueError("Segmentation map contained invalid class values")
198 |
199 | img = torch.from_numpy(img).float()
200 | lbl = torch.from_numpy(lbl).long()
201 |
202 | return img, lbl
203 |
204 | def decode_segmap(self, temp):
205 | r = temp.copy()
206 | g = temp.copy()
207 | b = temp.copy()
208 | for l in range(0, self.n_classes):
209 | r[temp == l] = self.label_colours[l][0]
210 | g[temp == l] = self.label_colours[l][1]
211 | b[temp == l] = self.label_colours[l][2]
212 |
213 | rgb = np.zeros((temp.shape[0], temp.shape[1], 3))
214 | rgb[:, :, 0] = r / 255.0
215 | rgb[:, :, 1] = g / 255.0
216 | rgb[:, :, 2] = b / 255.0
217 | return rgb
218 |
219 | def encode_segmap(self, mask):
220 | # Put all void classes to zero
221 | for _voidc in self.void_classes:
222 | mask[mask == _voidc] = self.ignore_index
223 | for _validc in self.valid_classes:
224 | mask[mask == _validc] = self.class_map[_validc]
225 | return mask
226 |
227 |
228 | if __name__ == "__main__":
229 | import matplotlib.pyplot as plt
230 |
231 | augmentations = Compose([Scale(2048), RandomRotate(10), RandomHorizontallyFlip(0.5)])
232 |
233 | local_path = "/datasets01/cityscapes/112817/"
234 | dst = cityscapesLoader(local_path, is_transform=True, augmentations=augmentations)
235 | bs = 4
236 | trainloader = data.DataLoader(dst, batch_size=bs, num_workers=0)
237 | for i, data_samples in enumerate(trainloader):
238 | imgs, labels = data_samples
239 | import pdb
240 |
241 | pdb.set_trace()
242 | imgs = imgs.numpy()[:, ::-1, :, :]
243 | imgs = np.transpose(imgs, [0, 2, 3, 1])
244 | f, axarr = plt.subplots(bs, 2)
245 | for j in range(bs):
246 | axarr[j][0].imshow(imgs[j])
247 | axarr[j][1].imshow(dst.decode_segmap(labels.numpy()[j]))
248 | plt.show()
249 | a = input()
250 | if a == "ex":
251 | break
252 | else:
253 | plt.close()
254 |
--------------------------------------------------------------------------------
/ptsemseg/loader/mapillary_vistas_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import torch
4 | import numpy as np
5 |
6 | from torch.utils import data
7 | from PIL import Image
8 |
9 | from ptsemseg.utils import recursive_glob
10 | from ptsemseg.augmentations import Compose, RandomHorizontallyFlip, RandomRotate
11 |
12 |
13 | class mapillaryVistasLoader(data.Dataset):
14 | def __init__(
15 | self,
16 | root,
17 | split="training",
18 | img_size=(640, 1280),
19 | is_transform=True,
20 | augmentations=None,
21 | test_mode=False,
22 | ):
23 | self.root = root
24 | self.split = split
25 | self.is_transform = is_transform
26 | self.augmentations = augmentations
27 | self.n_classes = 65
28 |
29 | self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size)
30 | self.mean = np.array([80.5423, 91.3162, 81.4312])
31 | self.files = {}
32 |
33 | self.images_base = os.path.join(self.root, self.split, "images")
34 | self.annotations_base = os.path.join(self.root, self.split, "labels")
35 |
36 | self.files[split] = recursive_glob(rootdir=self.images_base, suffix=".jpg")
37 |
38 | self.class_ids, self.class_names, self.class_colors = self.parse_config()
39 |
40 | self.ignore_id = 250
41 |
42 | if not self.files[split]:
43 | raise Exception("No files for split=[%s] found in %s" % (split, self.images_base))
44 |
45 | print("Found %d %s images" % (len(self.files[split]), split))
46 |
47 | def parse_config(self):
48 | with open(os.path.join(self.root, "config.json")) as config_file:
49 | config = json.load(config_file)
50 |
51 | labels = config["labels"]
52 |
53 | class_names = []
54 | class_ids = []
55 | class_colors = []
56 | print("There are {} labels in the config file".format(len(labels)))
57 | for label_id, label in enumerate(labels):
58 | class_names.append(label["readable"])
59 | class_ids.append(label_id)
60 | class_colors.append(label["color"])
61 |
62 |         return class_ids, class_names, class_colors  # order matches the unpacking in __init__
63 |
64 | def __len__(self):
65 | """__len__"""
66 | return len(self.files[self.split])
67 |
68 | def __getitem__(self, index):
69 | """__getitem__
70 | :param index:
71 | """
72 | img_path = self.files[self.split][index].rstrip()
73 | lbl_path = os.path.join(
74 | self.annotations_base, os.path.basename(img_path).replace(".jpg", ".png")
75 | )
76 |
77 | img = Image.open(img_path)
78 | lbl = Image.open(lbl_path)
79 |
80 | if self.augmentations is not None:
81 | img, lbl = self.augmentations(img, lbl)
82 |
83 | if self.is_transform:
84 | img, lbl = self.transform(img, lbl)
85 |
86 | return img, lbl
87 |
88 | def transform(self, img, lbl):
89 | if self.img_size == ("same", "same"):
90 | pass
91 | else:
92 | img = img.resize(
93 | (self.img_size[0], self.img_size[1]), resample=Image.LANCZOS
94 | ) # uint8 with RGB mode
95 | lbl = lbl.resize((self.img_size[0], self.img_size[1]))
96 | img = np.array(img).astype(np.float64) / 255.0
97 | img = torch.from_numpy(img.transpose(2, 0, 1)).float() # From HWC to CHW
98 | lbl = torch.from_numpy(np.array(lbl)).long()
99 | lbl[lbl == 65] = self.ignore_id
100 | return img, lbl
101 |
102 | def decode_segmap(self, temp):
103 | r = temp.copy()
104 | g = temp.copy()
105 | b = temp.copy()
106 | for l in range(0, self.n_classes):
107 | r[temp == l] = self.class_colors[l][0]
108 | g[temp == l] = self.class_colors[l][1]
109 | b[temp == l] = self.class_colors[l][2]
110 |
111 | rgb = np.zeros((temp.shape[0], temp.shape[1], 3))
112 | rgb[:, :, 0] = r / 255.0
113 | rgb[:, :, 1] = g / 255.0
114 | rgb[:, :, 2] = b / 255.0
115 | return rgb
116 |
117 |
118 | if __name__ == "__main__":
119 | augment = Compose([RandomHorizontallyFlip(), RandomRotate(6)])
120 |
121 | local_path = "/private/home/meetshah/datasets/seg/vistas/"
122 | dst = mapillaryVistasLoader(
123 | local_path, img_size=(512, 1024), is_transform=True, augmentations=augment
124 | )
125 | bs = 8
126 | trainloader = data.DataLoader(dst, batch_size=bs, num_workers=4, shuffle=True)
127 | for i, data_samples in enumerate(trainloader):
128 | x = dst.decode_segmap(data_samples[1][0].numpy())
129 | print("batch :", i)
130 |
--------------------------------------------------------------------------------
/ptsemseg/loader/mit_sceneparsing_benchmark_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | import scipy.misc as m
5 |
6 | from torch.utils import data
7 |
8 | from ptsemseg.utils import recursive_glob
9 |
10 |
11 | class MITSceneParsingBenchmarkLoader(data.Dataset):
12 | """MITSceneParsingBenchmarkLoader
13 |
14 | http://sceneparsing.csail.mit.edu/
15 |
16 | Data is derived from ADE20k, and can be downloaded from here:
17 | http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip
18 |
19 | NOTE: this loader is not designed to work with the original ADE20k dataset;
20 | for that you will need the ADE20kLoader
21 |
22 | This class can also be extended to load data for places challenge:
23 | https://github.com/CSAILVision/placeschallenge/tree/master/sceneparsing
24 |
25 | """
26 |
27 | def __init__(
28 | self,
29 | root,
30 | split="training",
31 | is_transform=False,
32 | img_size=512,
33 | augmentations=None,
34 | img_norm=True,
35 | test_mode=False,
36 | ):
37 | """__init__
38 |
39 | :param root:
40 | :param split:
41 | :param is_transform:
42 | :param img_size:
43 | """
44 | self.root = root
45 | self.split = split
46 | self.is_transform = is_transform
47 | self.augmentations = augmentations
48 | self.img_norm = img_norm
49 | self.n_classes = 151 # 0 is reserved for "other"
50 | self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size)
51 | self.mean = np.array([104.00699, 116.66877, 122.67892])
52 | self.files = {}
53 |
54 | self.images_base = os.path.join(self.root, "images", self.split)
55 | self.annotations_base = os.path.join(self.root, "annotations", self.split)
56 |
57 | self.files[split] = recursive_glob(rootdir=self.images_base, suffix=".jpg")
58 |
59 | if not self.files[split]:
60 | raise Exception("No files for split=[%s] found in %s" % (split, self.images_base))
61 |
62 | print("Found %d %s images" % (len(self.files[split]), split))
63 |
64 | def __len__(self):
65 | """__len__"""
66 | return len(self.files[self.split])
67 |
68 | def __getitem__(self, index):
69 | """__getitem__
70 |
71 | :param index:
72 | """
73 | img_path = self.files[self.split][index].rstrip()
74 | lbl_path = os.path.join(self.annotations_base, os.path.basename(img_path)[:-4] + ".png")
75 |
76 | img = m.imread(img_path, mode="RGB")
77 | img = np.array(img, dtype=np.uint8)
78 |
79 | lbl = m.imread(lbl_path)
80 | lbl = np.array(lbl, dtype=np.uint8)
81 |
82 | if self.augmentations is not None:
83 | img, lbl = self.augmentations(img, lbl)
84 |
85 | if self.is_transform:
86 | img, lbl = self.transform(img, lbl)
87 |
88 | return img, lbl
89 |
90 | def transform(self, img, lbl):
91 | """transform
92 |
93 | :param img:
94 | :param lbl:
95 | """
96 | if self.img_size == ("same", "same"):
97 | pass
98 | else:
99 | img = m.imresize(img, (self.img_size[0], self.img_size[1])) # uint8 with RGB mode
100 | img = img[:, :, ::-1] # RGB -> BGR
101 | img = img.astype(np.float64)
102 | img -= self.mean
103 | if self.img_norm:
104 | # Resize scales images from 0 to 255, thus we need
105 | # to divide by 255.0
106 | img = img.astype(float) / 255.0
107 | # NHWC -> NCHW
108 | img = img.transpose(2, 0, 1)
109 |
110 | classes = np.unique(lbl)
111 | lbl = lbl.astype(float)
112 | if self.img_size == ("same", "same"):
113 | pass
114 | else:
115 | lbl = m.imresize(lbl, (self.img_size[0], self.img_size[1]), "nearest", mode="F")
116 | lbl = lbl.astype(int)
117 |
118 | if not np.all(classes == np.unique(lbl)):
119 | print("WARN: resizing labels yielded fewer classes")
120 |
121 | if not np.all(np.unique(lbl) < self.n_classes):
122 | raise ValueError("Segmentation map contained invalid class values")
123 |
124 | img = torch.from_numpy(img).float()
125 | lbl = torch.from_numpy(lbl).long()
126 |
127 | return img, lbl
128 |
--------------------------------------------------------------------------------
/ptsemseg/loader/nyuv2_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import collections
3 | import torch
4 | import numpy as np
5 | import scipy.misc as m
6 |
7 | from torch.utils import data
8 |
9 | from ptsemseg.utils import recursive_glob
10 | from ptsemseg.augmentations import Compose, RandomHorizontallyFlip, RandomRotate, Scale
11 |
12 |
13 | class NYUv2Loader(data.Dataset):
14 | """
15 | NYUv2 loader
16 | Download From (only 13 classes):
17 | test source: http://www.doc.ic.ac.uk/~ahanda/nyu_test_rgb.tgz
18 | train source: http://www.doc.ic.ac.uk/~ahanda/nyu_train_rgb.tgz
19 | test_labels source:
20 | https://github.com/ankurhanda/nyuv2-meta-data/raw/master/test_labels_13/nyuv2_test_class13.tgz
21 | train_labels source:
22 | https://github.com/ankurhanda/nyuv2-meta-data/raw/master/train_labels_13/nyuv2_train_class13.tgz
23 |
24 | """
25 |
26 | def __init__(
27 | self,
28 | root,
29 | split="training",
30 | is_transform=False,
31 | img_size=(480, 640),
32 | augmentations=None,
33 | img_norm=True,
34 | test_mode=False,
35 | ):
36 | self.root = root
37 | self.is_transform = is_transform
38 | self.n_classes = 14
39 | self.augmentations = augmentations
40 | self.img_norm = img_norm
41 | self.test_mode = test_mode
42 | self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size)
43 | self.mean = np.array([104.00699, 116.66877, 122.67892])
44 | self.files = collections.defaultdict(list)
45 | self.cmap = self.color_map(normalized=False)
46 |
47 | split_map = {"training": "train", "val": "test"}
48 | self.split = split_map[split]
49 |
50 | for split in ["train", "test"]:
51 | file_list = recursive_glob(rootdir=self.root + split + "/", suffix="png")
52 | self.files[split] = file_list
53 |
54 | def __len__(self):
55 | return len(self.files[self.split])
56 |
57 | def __getitem__(self, index):
58 | img_path = self.files[self.split][index].rstrip()
59 | img_number = img_path.split("_")[-1][:4]
60 | lbl_path = os.path.join(
61 | self.root, self.split + "_annot", "new_nyu_class13_" + img_number + ".png"
62 | )
63 |
64 | img = m.imread(img_path)
65 | img = np.array(img, dtype=np.uint8)
66 |
67 | lbl = m.imread(lbl_path)
68 | lbl = np.array(lbl, dtype=np.uint8)
69 |
70 | if not (len(img.shape) == 3 and len(lbl.shape) == 2):
71 | return self.__getitem__(np.random.randint(0, self.__len__()))
72 |
73 | if self.augmentations is not None:
74 | img, lbl = self.augmentations(img, lbl)
75 |
76 | if self.is_transform:
77 | img, lbl = self.transform(img, lbl)
78 |
79 | return img, lbl
80 |
81 | def transform(self, img, lbl):
82 | img = m.imresize(img, (self.img_size[0], self.img_size[1])) # uint8 with RGB mode
83 | img = img[:, :, ::-1] # RGB -> BGR
84 | img = img.astype(np.float64)
85 | img -= self.mean
86 | if self.img_norm:
87 | # Resize scales images from 0 to 255, thus we need
88 | # to divide by 255.0
89 | img = img.astype(float) / 255.0
90 | # NHWC -> NCHW
91 | img = img.transpose(2, 0, 1)
92 |
93 | classes = np.unique(lbl)
94 | lbl = lbl.astype(float)
95 | lbl = m.imresize(lbl, (self.img_size[0], self.img_size[1]), "nearest", mode="F")
96 | lbl = lbl.astype(int)
97 | assert np.all(classes == np.unique(lbl))
98 |
99 | img = torch.from_numpy(img).float()
100 | lbl = torch.from_numpy(lbl).long()
101 | return img, lbl
102 |
103 | def color_map(self, N=256, normalized=False):
104 | """
105 | Return Color Map in PASCAL VOC format
106 | """
107 |
108 | def bitget(byteval, idx):
109 | return (byteval & (1 << idx)) != 0
110 |
111 | dtype = "float32" if normalized else "uint8"
112 | cmap = np.zeros((N, 3), dtype=dtype)
113 | for i in range(N):
114 | r = g = b = 0
115 | c = i
116 | for j in range(8):
117 | r = r | (bitget(c, 0) << 7 - j)
118 | g = g | (bitget(c, 1) << 7 - j)
119 | b = b | (bitget(c, 2) << 7 - j)
120 | c = c >> 3
121 |
122 | cmap[i] = np.array([r, g, b])
123 |
124 | cmap = cmap / 255.0 if normalized else cmap
125 | return cmap
126 |
127 | def decode_segmap(self, temp):
128 | r = temp.copy()
129 | g = temp.copy()
130 | b = temp.copy()
131 | for l in range(0, self.n_classes):
132 | r[temp == l] = self.cmap[l, 0]
133 | g[temp == l] = self.cmap[l, 1]
134 | b[temp == l] = self.cmap[l, 2]
135 |
136 | rgb = np.zeros((temp.shape[0], temp.shape[1], 3))
137 | rgb[:, :, 0] = r / 255.0
138 | rgb[:, :, 1] = g / 255.0
139 | rgb[:, :, 2] = b / 255.0
140 | return rgb
141 |
142 |
143 | if __name__ == "__main__":
144 | import matplotlib.pyplot as plt
145 |
146 | augmentations = Compose([Scale(512), RandomRotate(10), RandomHorizontallyFlip()])
147 |
148 | local_path = "/home/meet/datasets/NYUv2/"
149 | dst = NYUv2Loader(local_path, is_transform=True, augmentations=augmentations)
150 | bs = 4
151 | trainloader = data.DataLoader(dst, batch_size=bs, num_workers=0)
152 | for i, datas in enumerate(trainloader):
153 | imgs, labels = datas
154 | imgs = imgs.numpy()[:, ::-1, :, :]
155 | imgs = np.transpose(imgs, [0, 2, 3, 1])
156 | f, axarr = plt.subplots(bs, 2)
157 | for j in range(bs):
158 | axarr[j][0].imshow(imgs[j])
159 | axarr[j][1].imshow(dst.decode_segmap(labels.numpy()[j]))
160 | plt.show()
161 | a = input()
162 | if a == "ex":
163 | break
164 | else:
165 | plt.close()
166 |
--------------------------------------------------------------------------------
/ptsemseg/loader/pascal_voc_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os.path import join as pjoin
3 | import collections
4 | import json
5 | import torch
6 | import numpy as np
7 | import scipy.misc as m
8 | import scipy.io as io
9 | import matplotlib.pyplot as plt
10 | import glob
11 |
12 | from PIL import Image
13 | from tqdm import tqdm
14 | from torch.utils import data
15 | from torchvision import transforms
16 |
17 |
18 | class pascalVOCLoader(data.Dataset):
19 | """Data loader for the Pascal VOC semantic segmentation dataset.
20 |
21 | Annotations from both the original VOC data (which consist of RGB images
22 |     in which colours map to specific classes) and the SBD (Berkeley) dataset
23 | (where annotations are stored as .mat files) are converted into a common
24 | `label_mask` format. Under this format, each mask is an (M,N) array of
25 | integer values from 0 to 21, where 0 represents the background class.
26 |
27 | The label masks are stored in a new folder, called `pre_encoded`, which
28 | is added as a subdirectory of the `SegmentationClass` folder in the
29 | original Pascal VOC data layout.
30 |
31 | A total of five data splits are provided for working with the VOC data:
32 | train: The original VOC 2012 training data - 1464 images
33 | val: The original VOC 2012 validation data - 1449 images
34 | trainval: The combination of `train` and `val` - 2913 images
35 | train_aug: The unique images present in both the train split and
36 | training images from SBD: - 8829 images (the unique members
37 | of the result of combining lists of length 1464 and 8498)
38 | train_aug_val: The original VOC 2012 validation data minus the images
39 | present in `train_aug` (This is done with the same logic as
40 | the validation set used in FCN PAMI paper, but with VOC 2012
41 | rather than VOC 2011) - 904 images
42 | """
43 |
44 | def __init__(
45 | self,
46 | root,
47 | sbd_path=None,
48 | split="train_aug",
49 | is_transform=False,
50 | img_size=512,
51 | augmentations=None,
52 | img_norm=True,
53 | test_mode=False,
54 | ):
55 | self.root = root
56 | self.sbd_path = sbd_path
57 | self.split = split
58 | self.is_transform = is_transform
59 | self.augmentations = augmentations
60 | self.img_norm = img_norm
61 | self.test_mode = test_mode
62 | self.n_classes = 21
63 | self.mean = np.array([104.00699, 116.66877, 122.67892])
64 | self.files = collections.defaultdict(list)
65 | self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size)
66 |
67 | if not self.test_mode:
68 | for split in ["train", "val", "trainval"]:
69 | path = pjoin(self.root, "ImageSets/Segmentation", split + ".txt")
70 | file_list = tuple(open(path, "r"))
71 | file_list = [id_.rstrip() for id_ in file_list]
72 | self.files[split] = file_list
73 | self.setup_annotations()
74 |
75 | self.tf = transforms.Compose(
76 | [
77 | transforms.ToTensor(),
78 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
79 | ]
80 | )
81 |
82 | def __len__(self):
83 | return len(self.files[self.split])
84 |
85 | def __getitem__(self, index):
86 | im_name = self.files[self.split][index]
87 | im_path = pjoin(self.root, "JPEGImages", im_name + ".jpg")
88 | lbl_path = pjoin(self.root, "SegmentationClass/pre_encoded", im_name + ".png")
89 | im = Image.open(im_path)
90 | lbl = Image.open(lbl_path)
91 | if self.augmentations is not None:
92 | im, lbl = self.augmentations(im, lbl)
93 | if self.is_transform:
94 | im, lbl = self.transform(im, lbl)
95 | return im, lbl
96 |
97 | def transform(self, img, lbl):
98 | if self.img_size == ("same", "same"):
99 | pass
100 | else:
101 | img = img.resize((self.img_size[0], self.img_size[1])) # uint8 with RGB mode
102 | lbl = lbl.resize((self.img_size[0], self.img_size[1]))
103 | img = self.tf(img)
104 | lbl = torch.from_numpy(np.array(lbl)).long()
105 | lbl[lbl == 255] = 0
106 | return img, lbl
107 |
108 | def get_pascal_labels(self):
109 | """Load the mapping that associates pascal classes with label colors
110 |
111 | Returns:
112 | np.ndarray with dimensions (21, 3)
113 | """
114 | return np.asarray(
115 | [
116 | [0, 0, 0],
117 | [128, 0, 0],
118 | [0, 128, 0],
119 | [128, 128, 0],
120 | [0, 0, 128],
121 | [128, 0, 128],
122 | [0, 128, 128],
123 | [128, 128, 128],
124 | [64, 0, 0],
125 | [192, 0, 0],
126 | [64, 128, 0],
127 | [192, 128, 0],
128 | [64, 0, 128],
129 | [192, 0, 128],
130 | [64, 128, 128],
131 | [192, 128, 128],
132 | [0, 64, 0],
133 | [128, 64, 0],
134 | [0, 192, 0],
135 | [128, 192, 0],
136 | [0, 64, 128],
137 | ]
138 | )
139 |
140 | def encode_segmap(self, mask):
141 | """Encode segmentation label images as pascal classes
142 |
143 | Args:
144 | mask (np.ndarray): raw segmentation label image of dimension
145 | (M, N, 3), in which the Pascal classes are encoded as colours.
146 |
147 | Returns:
148 | (np.ndarray): class map with dimensions (M,N), where the value at
149 | a given location is the integer denoting the class index.
150 | """
151 | mask = mask.astype(int)
152 | label_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int16)
153 | for ii, label in enumerate(self.get_pascal_labels()):
154 | label_mask[np.where(np.all(mask == label, axis=-1))[:2]] = ii
155 | label_mask = label_mask.astype(int)
156 | return label_mask
157 |
158 | def decode_segmap(self, label_mask, plot=False):
159 | """Decode segmentation class labels into a color image
160 |
161 | Args:
162 | label_mask (np.ndarray): an (M,N) array of integer values denoting
163 | the class label at each spatial location.
164 | plot (bool, optional): whether to show the resulting color image
165 | in a figure.
166 |
167 | Returns:
168 | (np.ndarray, optional): the resulting decoded color image.
169 | """
170 | label_colours = self.get_pascal_labels()
171 | r = label_mask.copy()
172 | g = label_mask.copy()
173 | b = label_mask.copy()
174 | for ll in range(0, self.n_classes):
175 | r[label_mask == ll] = label_colours[ll, 0]
176 | g[label_mask == ll] = label_colours[ll, 1]
177 | b[label_mask == ll] = label_colours[ll, 2]
178 | rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3))
179 | rgb[:, :, 0] = r / 255.0
180 | rgb[:, :, 1] = g / 255.0
181 | rgb[:, :, 2] = b / 255.0
182 | if plot:
183 | plt.imshow(rgb)
184 | plt.show()
185 | else:
186 | return rgb
187 |
188 | def setup_annotations(self):
189 | """Sets up Berkley annotations by adding image indices to the
190 | `train_aug` split and pre-encode all segmentation labels into the
191 | common label_mask format (if this has not already been done). This
192 | function also defines the `train_aug` and `train_aug_val` data splits
193 | according to the description in the class docstring
194 | """
195 | sbd_path = self.sbd_path
196 | target_path = pjoin(self.root, "SegmentationClass/pre_encoded")
197 | if not os.path.exists(target_path):
198 | os.makedirs(target_path)
199 | path = pjoin(sbd_path, "dataset/train.txt")
200 | sbd_train_list = tuple(open(path, "r"))
201 | sbd_train_list = [id_.rstrip() for id_ in sbd_train_list]
202 | train_aug = self.files["train"] + sbd_train_list
203 |
204 | # keep unique elements (stable)
205 | train_aug = [train_aug[i] for i in sorted(np.unique(train_aug, return_index=True)[1])]
206 | self.files["train_aug"] = train_aug
207 | set_diff = set(self.files["val"]) - set(train_aug) # remove overlap
208 | self.files["train_aug_val"] = list(set_diff)
209 |
210 | pre_encoded = glob.glob(pjoin(target_path, "*.png"))
211 | expected = np.unique(self.files["train_aug"] + self.files["val"]).size
212 |
213 | if len(pre_encoded) != expected:
214 | print("Pre-encoding segmentation masks...")
215 | for ii in tqdm(sbd_train_list):
216 | lbl_path = pjoin(sbd_path, "dataset/cls", ii + ".mat")
217 | data = io.loadmat(lbl_path)
218 | lbl = data["GTcls"][0]["Segmentation"][0].astype(np.int32)
219 | lbl = m.toimage(lbl, high=lbl.max(), low=lbl.min())
220 | m.imsave(pjoin(target_path, ii + ".png"), lbl)
221 |
222 | for ii in tqdm(self.files["trainval"]):
223 | fname = ii + ".png"
224 | lbl_path = pjoin(self.root, "SegmentationClass", fname)
225 | lbl = self.encode_segmap(m.imread(lbl_path))
226 | lbl = m.toimage(lbl, high=lbl.max(), low=lbl.min())
227 | m.imsave(pjoin(target_path, fname), lbl)
228 |
229 | assert expected == 9733, "unexpected dataset sizes"
230 |
231 |
232 | # Leave code for debugging purposes
233 | # import ptsemseg.augmentations as aug
234 | # if __name__ == '__main__':
235 | # # local_path = '/home/meetshah1995/datasets/VOCdevkit/VOC2012/'
236 | # bs = 4
237 | # augs = aug.Compose([aug.RandomRotate(10), aug.RandomHorizontallyFlip()])
238 | # dst = pascalVOCLoader(root=local_path, is_transform=True, augmentations=augs)
239 | # trainloader = data.DataLoader(dst, batch_size=bs)
240 | # for i, data in enumerate(trainloader):
241 | # imgs, labels = data
242 | # imgs = imgs.numpy()[:, ::-1, :, :]
243 | # imgs = np.transpose(imgs, [0,2,3,1])
244 | # f, axarr = plt.subplots(bs, 2)
245 | # for j in range(bs):
246 | # axarr[j][0].imshow(imgs[j])
247 | # axarr[j][1].imshow(dst.decode_segmap(labels.numpy()[j]))
248 | # plt.show()
249 | # a = raw_input()
250 | # if a == 'ex':
251 | # break
252 | # else:
253 | # plt.close()
254 |
--------------------------------------------------------------------------------
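A quick, self-contained sketch of the colour encoding used by `get_pascal_labels`, `encode_segmap` and `decode_segmap` above: a colour-coded label image is mapped to integer class indices and back through the VOC palette. Only numpy is needed; the tiny 2x2 mask and the three palette entries are illustrative.

import numpy as np

palette = np.asarray([[0, 0, 0], [128, 0, 0], [0, 128, 0]])  # first three VOC colours

# toy 2x2 colour-coded label image: background, aeroplane, bicycle, background
rgb_mask = np.array([[palette[0], palette[1]],
                     [palette[2], palette[0]]], dtype=np.uint8)

# encode: colour -> class index (same lookup as encode_segmap)
label_mask = np.zeros(rgb_mask.shape[:2], dtype=np.int16)
for idx, colour in enumerate(palette):
    label_mask[np.all(rgb_mask == colour, axis=-1)] = idx
print(label_mask)  # [[0 1]
                   #  [2 0]]

# decode: class index -> colour (decode_segmap does the same lookup, then scales to [0, 1])
decoded = palette[label_mask]
assert np.array_equal(decoded, rgb_mask)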
/ptsemseg/loader/sunrgbd_loader.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import torch
3 | import numpy as np
4 | import scipy.misc as m
5 |
6 | from torch.utils import data
7 |
8 | from ptsemseg.utils import recursive_glob
9 | from ptsemseg.augmentations import Compose, RandomHorizontallyFlip, RandomRotate, Scale
10 |
11 |
12 | class SUNRGBDLoader(data.Dataset):
13 | """SUNRGBD loader
14 |
15 | Download From:
16 | http://www.doc.ic.ac.uk/~ahanda/SUNRGBD-test_images.tgz
17 | test source: http://www.doc.ic.ac.uk/~ahanda/SUNRGBD-test_images.tgz
18 | train source: http://www.doc.ic.ac.uk/~ahanda/SUNRGBD-train_images.tgz
19 |
20 | first 5050 in this is test, later 5051 is train
21 | test and train labels source:
22 | https://github.com/ankurhanda/sunrgbd-meta-data/raw/master/sunrgbd_train_test_labels.tar.gz
23 | """
24 |
25 | def __init__(
26 | self,
27 | root,
28 | split="training",
29 | is_transform=False,
30 | img_size=(480, 640),
31 | augmentations=None,
32 | img_norm=True,
33 | test_mode=False,
34 | ):
35 | self.root = root
36 | self.is_transform = is_transform
37 | self.n_classes = 38
38 | self.augmentations = augmentations
39 | self.img_norm = img_norm
40 | self.test_mode = test_mode
41 | self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size)
42 | self.mean = np.array([104.00699, 116.66877, 122.67892])
43 | self.files = collections.defaultdict(list)
44 | self.anno_files = collections.defaultdict(list)
45 | self.cmap = self.color_map(normalized=False)
46 |
47 | split_map = {"training": "train", "val": "test"}
48 | self.split = split_map[split]
49 |
50 | for split in ["train", "test"]:
51 | file_list = sorted(recursive_glob(rootdir=self.root + split + "/", suffix="jpg"))
52 | self.files[split] = file_list
53 |
54 | for split in ["train", "test"]:
55 | file_list = sorted(
56 | recursive_glob(rootdir=self.root + "annotations/" + split + "/", suffix="png")
57 | )
58 | self.anno_files[split] = file_list
59 |
60 | def __len__(self):
61 | return len(self.files[self.split])
62 |
63 | def __getitem__(self, index):
64 | img_path = self.files[self.split][index].rstrip()
65 | lbl_path = self.anno_files[self.split][index].rstrip()
66 | # img_number = img_path.split('/')[-1]
67 | # lbl_path = os.path.join(self.root, 'annotations', img_number).replace('jpg', 'png')
68 |
69 | img = m.imread(img_path)
70 | img = np.array(img, dtype=np.uint8)
71 |
72 | lbl = m.imread(lbl_path)
73 | lbl = np.array(lbl, dtype=np.uint8)
74 |
75 | if not (len(img.shape) == 3 and len(lbl.shape) == 2):
76 | return self.__getitem__(np.random.randint(0, self.__len__()))
77 |
78 | if self.augmentations is not None:
79 | img, lbl = self.augmentations(img, lbl)
80 |
81 | if self.is_transform:
82 | img, lbl = self.transform(img, lbl)
83 |
84 | return img, lbl
85 |
86 | def transform(self, img, lbl):
87 | img = m.imresize(img, (self.img_size[0], self.img_size[1])) # uint8 with RGB mode
88 | img = img[:, :, ::-1] # RGB -> BGR
89 | img = img.astype(np.float64)
90 | img -= self.mean
91 | if self.img_norm:
92 | # Resize scales images from 0 to 255, thus we need
93 | # to divide by 255.0
94 | img = img.astype(float) / 255.0
95 | # HWC -> CHW
96 | img = img.transpose(2, 0, 1)
97 |
98 | classes = np.unique(lbl)
99 | lbl = lbl.astype(float)
100 | lbl = m.imresize(lbl, (self.img_size[0], self.img_size[1]), "nearest", mode="F")
101 | lbl = lbl.astype(int)
102 | assert np.all(classes == np.unique(lbl))
103 |
104 | img = torch.from_numpy(img).float()
105 | lbl = torch.from_numpy(lbl).long()
106 | return img, lbl
107 |
108 | def color_map(self, N=256, normalized=False):
109 | """
110 | Return Color Map in PASCAL VOC format
111 | """
112 |
113 | def bitget(byteval, idx):
114 | return (byteval & (1 << idx)) != 0
115 |
116 | dtype = "float32" if normalized else "uint8"
117 | cmap = np.zeros((N, 3), dtype=dtype)
118 | for i in range(N):
119 | r = g = b = 0
120 | c = i
121 | for j in range(8):
122 | r = r | (bitget(c, 0) << 7 - j)
123 | g = g | (bitget(c, 1) << 7 - j)
124 | b = b | (bitget(c, 2) << 7 - j)
125 | c = c >> 3
126 |
127 | cmap[i] = np.array([r, g, b])
128 |
129 | cmap = cmap / 255.0 if normalized else cmap
130 | return cmap
131 |
132 | def decode_segmap(self, temp):
133 | r = temp.copy()
134 | g = temp.copy()
135 | b = temp.copy()
136 | for l in range(0, self.n_classes):
137 | r[temp == l] = self.cmap[l, 0]
138 | g[temp == l] = self.cmap[l, 1]
139 | b[temp == l] = self.cmap[l, 2]
140 |
141 | rgb = np.zeros((temp.shape[0], temp.shape[1], 3))
142 | rgb[:, :, 0] = r / 255.0
143 | rgb[:, :, 1] = g / 255.0
144 | rgb[:, :, 2] = b / 255.0
145 | return rgb
146 |
147 |
148 | if __name__ == "__main__":
149 | import matplotlib.pyplot as plt
150 |
151 | augmentations = Compose([Scale(512), RandomRotate(10), RandomHorizontallyFlip()])
152 |
153 | local_path = "/home/meet/datasets/SUNRGBD/"
154 | dst = SUNRGBDLoader(local_path, is_transform=True, augmentations=augmentations)
155 | bs = 4
156 | trainloader = data.DataLoader(dst, batch_size=bs, num_workers=0)
157 | for i, data_samples in enumerate(trainloader):
158 | imgs, labels = data_samples
159 | imgs = imgs.numpy()[:, ::-1, :, :]
160 | imgs = np.transpose(imgs, [0, 2, 3, 1])
161 | f, axarr = plt.subplots(bs, 2)
162 | for j in range(bs):
163 | axarr[j][0].imshow(imgs[j])
164 | axarr[j][1].imshow(dst.decode_segmap(labels.numpy()[j]))
165 | plt.show()
166 | a = input()
167 | if a == "ex":
168 | break
169 | else:
170 | plt.close()
171 |
--------------------------------------------------------------------------------
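The `color_map` method above builds the standard PASCAL-style palette by interleaving the bits of the class index into the R, G and B channels. A standalone re-implementation of the same bit trick (an equivalent sketch for illustration, not the class method itself) makes the pattern easy to verify:

import numpy as np

def pascal_colour(i):
    r = g = b = 0
    c = i
    for j in range(8):
        r |= ((c >> 0) & 1) << (7 - j)   # bit 0 of c goes into the red channel
        g |= ((c >> 1) & 1) << (7 - j)   # bit 1 into green
        b |= ((c >> 2) & 1) << (7 - j)   # bit 2 into blue
        c >>= 3
    return np.array([r, g, b], dtype=np.uint8)

print(pascal_colour(1))  # [128   0   0] -- matches the VOC palette entry for class 1
print(pascal_colour(5))  # [128   0 128] -- matches the VOC palette entry for class 5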
/ptsemseg/loss/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import functools
3 |
4 | from ptsemseg.loss.loss import (
5 | cross_entropy2d,
6 | bootstrapped_cross_entropy2d,
7 | multi_scale_cross_entropy2d,
8 | )
9 |
10 |
11 | logger = logging.getLogger("ptsemseg")
12 |
13 | key2loss = {
14 | "cross_entropy": cross_entropy2d,
15 | "bootstrapped_cross_entropy": bootstrapped_cross_entropy2d,
16 | "multi_scale_cross_entropy": multi_scale_cross_entropy2d,
17 | }
18 |
19 |
20 | def get_loss_function(cfg):
21 | if cfg["training"]["loss"] is None:
22 | logger.info("Using default cross entropy loss")
23 | return cross_entropy2d
24 |
25 | else:
26 | loss_dict = cfg["training"]["loss"]
27 | loss_name = loss_dict["name"]
28 | loss_params = {k: v for k, v in loss_dict.items() if k != "name"}
29 |
30 | if loss_name not in key2loss:
31 | raise NotImplementedError("Loss {} not implemented".format(loss_name))
32 |
33 | logger.info("Using {} with {} params".format(loss_name, loss_params))
34 | return functools.partial(key2loss[loss_name], **loss_params)
35 |
--------------------------------------------------------------------------------
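A short sketch of how a training config resolves to a loss callable through `get_loss_function`. The cfg layout (`training` -> `loss` -> `name` plus extra keyword arguments) follows the parsing above; the concrete `K` value is only an illustrative assumption, and the snippet assumes the `ptsemseg` package is importable.

from ptsemseg.loss import get_loss_function

# extra keys (everything except "name") are forwarded to the loss via functools.partial
cfg = {"training": {"loss": {"name": "bootstrapped_cross_entropy", "K": 4096}}}
loss_fn = get_loss_function(cfg)

# a missing loss entry falls back to plain cross_entropy2d
cfg_default = {"training": {"loss": None}}
default_fn = get_loss_function(cfg_default)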
/ptsemseg/loss/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | def cross_entropy2d(input, target, weight=None, size_average=True):
6 | n, c, h, w = input.size()
7 | nt, ht, wt = target.size()
8 |
9 | # Handle inconsistent size between input and target
10 | if h != ht or w != wt:  # upsample logits to the label size
11 | input = F.interpolate(input, size=(ht, wt), mode="bilinear", align_corners=True)
12 |
13 | input = input.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
14 | target = target.view(-1)
15 | loss = F.cross_entropy(
16 | input, target, weight=weight, size_average=size_average, ignore_index=250
17 | )
18 | return loss
19 |
20 |
21 | def multi_scale_cross_entropy2d(input, target, weight=None, size_average=True, scale_weight=None):
22 | if not isinstance(input, tuple):
23 | return cross_entropy2d(input=input, target=target, weight=weight, size_average=size_average)
24 |
25 | # Auxiliary training for PSPNet [1.0, 0.4] and ICNet [1.0, 0.4, 0.16]
26 | if scale_weight is None: # scale_weight: torch tensor type
27 | n_inp = len(input)
28 | scale = 0.4
29 | scale_weight = torch.pow(scale * torch.ones(n_inp), torch.arange(n_inp).float()).to(
30 | target.device
31 | )
32 |
33 | loss = 0.0
34 | for i, inp in enumerate(input):
35 | loss = loss + scale_weight[i] * cross_entropy2d(
36 | input=inp, target=target, weight=weight, size_average=size_average
37 | )
38 |
39 | return loss
40 |
41 |
42 | def bootstrapped_cross_entropy2d(input, target, K, weight=None, size_average=True):
43 |
44 | batch_size = input.size()[0]
45 |
46 | def _bootstrap_xentropy_single(input, target, K, weight=None, size_average=True):
47 |
48 | n, c, h, w = input.size()
49 | input = input.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
50 | target = target.view(-1)
51 | loss = F.cross_entropy(
52 | input, target, weight=weight, reduce=False, size_average=False, ignore_index=250
53 | )
54 |
55 | topk_loss, _ = loss.topk(K)
56 | reduced_topk_loss = topk_loss.sum() / K
57 |
58 | return reduced_topk_loss
59 |
60 | loss = 0.0
61 | # Bootstrap from each image not entire batch
62 | for i in range(batch_size):
63 | loss += _bootstrap_xentropy_single(
64 | input=torch.unsqueeze(input[i], 0),
65 | target=torch.unsqueeze(target[i], 0),
66 | K=K,
67 | weight=weight,
68 | size_average=size_average,
69 | )
70 | return loss / float(batch_size)
71 |
--------------------------------------------------------------------------------
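A shape sanity check for the losses above, using random tensors only (no real data); a hedged sketch assuming the `ptsemseg` package is importable. Logits are `(N, C, H, W)`, labels are `(N, H, W)` integer maps; `cross_entropy2d` bilinearly upsamples the logits when the label map is larger, while the bootstrapped variant averages only the `K` hardest pixels per image.

import torch
from ptsemseg.loss.loss import cross_entropy2d, bootstrapped_cross_entropy2d

logits = torch.randn(2, 21, 32, 32)              # (N, C, H, W) raw scores
labels = torch.randint(0, 21, (2, 32, 32))       # (N, H, W) integer class labels

loss = cross_entropy2d(logits, labels)

# labels at a higher resolution: the logits are upsampled to (64, 64) internally
labels_big = torch.randint(0, 21, (2, 64, 64))
loss_big = cross_entropy2d(logits, labels_big)

# bootstrapped loss keeps the 256 hardest of the 32*32 pixels in each image
b_loss = bootstrapped_cross_entropy2d(logits, labels, K=256)
print(loss.item(), loss_big.item(), b_loss.item())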
/ptsemseg/metrics.py:
--------------------------------------------------------------------------------
1 | # Adapted from score written by wkentaro
2 | # https://github.com/wkentaro/pytorch-fcn/blob/master/torchfcn/utils.py
3 |
4 | import numpy as np
5 |
6 |
7 | class runningScore(object):
8 | def __init__(self, n_classes):
9 | self.n_classes = n_classes
10 | self.confusion_matrix = np.zeros((n_classes, n_classes))
11 |
12 | def _fast_hist(self, label_true, label_pred, n_class):
13 | mask = (label_true >= 0) & (label_true < n_class)
14 | hist = np.bincount(
15 | n_class * label_true[mask].astype(int) + label_pred[mask], minlength=n_class ** 2
16 | ).reshape(n_class, n_class)
17 | return hist
18 |
19 | def update(self, label_trues, label_preds):
20 | for lt, lp in zip(label_trues, label_preds):
21 | self.confusion_matrix += self._fast_hist(lt.flatten(), lp.flatten(), self.n_classes)
22 |
23 | def get_scores(self):
24 | """Returns accuracy score evaluation result.
25 | - overall accuracy
26 | - mean accuracy
27 | - mean IU
28 | - fwavacc
29 | """
30 | hist = self.confusion_matrix
31 | acc = np.diag(hist).sum() / hist.sum()
32 | acc_cls = np.diag(hist) / hist.sum(axis=1)
33 | acc_cls = np.nanmean(acc_cls)
34 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
35 | mean_iu = np.nanmean(iu)
36 | freq = hist.sum(axis=1) / hist.sum()
37 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
38 | cls_iu = dict(zip(range(self.n_classes), iu))
39 |
40 | return (
41 | {
42 | "Overall Acc: \t": acc,
43 | "Mean Acc : \t": acc_cls,
44 | "FreqW Acc : \t": fwavacc,
45 | "Mean IoU : \t": mean_iu,
46 | },
47 | cls_iu,
48 | )
49 |
50 | def reset(self):
51 | self.confusion_matrix = np.zeros((self.n_classes, self.n_classes))
52 |
53 |
54 | class averageMeter(object):
55 | """Computes and stores the average and current value"""
56 |
57 | def __init__(self):
58 | self.reset()
59 |
60 | def reset(self):
61 | self.val = 0
62 | self.avg = 0
63 | self.sum = 0
64 | self.count = 0
65 |
66 | def update(self, val, n=1):
67 | self.val = val
68 | self.sum += val * n
69 | self.count += n
70 | self.avg = self.sum / self.count
71 |
--------------------------------------------------------------------------------
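A tiny worked example for `runningScore` above (two classes, one 2x2 image), assuming `ptsemseg` is importable. The accumulated confusion matrix gives an overall accuracy of 3/4, IoU 1/2 for class 0 and IoU 2/3 for class 1.

import numpy as np
from ptsemseg.metrics import runningScore

gt = np.array([[0, 0], [1, 1]])      # ground-truth labels
pred = np.array([[0, 1], [1, 1]])    # one background pixel predicted as class 1

metrics = runningScore(n_classes=2)
metrics.update([gt], [pred])
score, class_iou = metrics.get_scores()
print(score)      # Overall Acc 0.75, Mean IoU = (0.5 + 0.666...) / 2
print(class_iou)  # {0: 0.5, 1: 0.666...}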
/ptsemseg/models/__init__.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import torchvision.models as models
3 |
4 | from ptsemseg.models.fcn import fcn8s, fcn16s, fcn32s
5 | from ptsemseg.models.segnet import segnet
6 | from ptsemseg.models.unet import unet
7 | from ptsemseg.models.pspnet import pspnet
8 | from ptsemseg.models.icnet import icnet
9 | from ptsemseg.models.linknet import linknet
10 | from ptsemseg.models.frrn import frrn
11 |
12 |
13 | def get_model(model_dict, n_classes, version=None):
14 | name = model_dict["arch"]
15 | model = _get_model_instance(name)
16 | param_dict = copy.deepcopy(model_dict)
17 | param_dict.pop("arch")
18 |
19 | if name in ["frrnA", "frrnB"]:
20 | model = model(n_classes, **param_dict)
21 |
22 | elif name in ["fcn32s", "fcn16s", "fcn8s"]:
23 | model = model(n_classes=n_classes, **param_dict)
24 | vgg16 = models.vgg16(pretrained=True)
25 | model.init_vgg16_params(vgg16)
26 |
27 | elif name == "segnet":
28 | model = model(n_classes=n_classes, **param_dict)
29 | vgg16 = models.vgg16(pretrained=True)
30 | model.init_vgg16_params(vgg16)
31 |
32 | elif name == "unet":
33 | model = model(n_classes=n_classes, **param_dict)
34 |
35 | elif name == "pspnet":
36 | model = model(n_classes=n_classes, **param_dict)
37 |
38 | elif name == "icnet":
39 | model = model(n_classes=n_classes, **param_dict)
40 |
41 | elif name == "icnetBN":
42 | model = model(n_classes=n_classes, **param_dict)
43 |
44 | else:
45 | model = model(n_classes=n_classes, **param_dict)
46 |
47 | return model
48 |
49 |
50 | def _get_model_instance(name):
51 | try:
52 | return {
53 | "fcn32s": fcn32s,
54 | "fcn8s": fcn8s,
55 | "fcn16s": fcn16s,
56 | "unet": unet,
57 | "segnet": segnet,
58 | "pspnet": pspnet,
59 | "icnet": icnet,
60 | "icnetBN": icnet,
61 | "linknet": linknet,
62 | "frrnA": frrn,
63 | "frrnB": frrn,
64 | }[name]
65 | except KeyError:
66 | raise NotImplementedError("Model {} not available".format(name))
67 |
--------------------------------------------------------------------------------
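A hedged sketch of resolving a config entry into a model via `get_model` above (assumes `ptsemseg` and `torchvision` are importable). The `unet` arch is chosen because it does not trigger the pretrained VGG16 download that the FCN and SegNet branches do; the `feature_scale` value is only illustrative and, like every key other than `arch`, is forwarded to the model constructor.

from ptsemseg.models import get_model

model_dict = {"arch": "unet", "feature_scale": 4}
model = get_model(model_dict, n_classes=19)
print(type(model).__name__)  # unet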
/ptsemseg/models/fcn.py:
--------------------------------------------------------------------------------
1 | import functools
2 |
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from ptsemseg.models.utils import get_upsampling_weight
7 | from ptsemseg.loss import cross_entropy2d
8 |
9 |
10 | # FCN32s
11 | class fcn32s(nn.Module):
12 | def __init__(self, n_classes=21, learned_billinear=False):
13 | super(fcn32s, self).__init__()
14 | self.learned_billinear = learned_billinear
15 | self.n_classes = n_classes
16 | self.loss = functools.partial(cross_entropy2d, size_average=False)
17 |
18 | self.conv_block1 = nn.Sequential(
19 | nn.Conv2d(3, 64, 3, padding=100),
20 | nn.ReLU(inplace=True),
21 | nn.Conv2d(64, 64, 3, padding=1),
22 | nn.ReLU(inplace=True),
23 | nn.MaxPool2d(2, stride=2, ceil_mode=True),
24 | )
25 |
26 | self.conv_block2 = nn.Sequential(
27 | nn.Conv2d(64, 128, 3, padding=1),
28 | nn.ReLU(inplace=True),
29 | nn.Conv2d(128, 128, 3, padding=1),
30 | nn.ReLU(inplace=True),
31 | nn.MaxPool2d(2, stride=2, ceil_mode=True),
32 | )
33 |
34 | self.conv_block3 = nn.Sequential(
35 | nn.Conv2d(128, 256, 3, padding=1),
36 | nn.ReLU(inplace=True),
37 | nn.Conv2d(256, 256, 3, padding=1),
38 | nn.ReLU(inplace=True),
39 | nn.Conv2d(256, 256, 3, padding=1),
40 | nn.ReLU(inplace=True),
41 | nn.MaxPool2d(2, stride=2, ceil_mode=True),
42 | )
43 |
44 | self.conv_block4 = nn.Sequential(
45 | nn.Conv2d(256, 512, 3, padding=1),
46 | nn.ReLU(inplace=True),
47 | nn.Conv2d(512, 512, 3, padding=1),
48 | nn.ReLU(inplace=True),
49 | nn.Conv2d(512, 512, 3, padding=1),
50 | nn.ReLU(inplace=True),
51 | nn.MaxPool2d(2, stride=2, ceil_mode=True),
52 | )
53 |
54 | self.conv_block5 = nn.Sequential(
55 | nn.Conv2d(512, 512, 3, padding=1),
56 | nn.ReLU(inplace=True),
57 | nn.Conv2d(512, 512, 3, padding=1),
58 | nn.ReLU(inplace=True),
59 | nn.Conv2d(512, 512, 3, padding=1),
60 | nn.ReLU(inplace=True),
61 | nn.MaxPool2d(2, stride=2, ceil_mode=True),
62 | )
63 |
64 | self.classifier = nn.Sequential(
65 | nn.Conv2d(512, 4096, 7),
66 | nn.ReLU(inplace=True),
67 | nn.Dropout2d(),
68 | nn.Conv2d(4096, 4096, 1),
69 | nn.ReLU(inplace=True),
70 | nn.Dropout2d(),
71 | nn.Conv2d(4096, self.n_classes, 1),
72 | )
73 |
74 | if self.learned_billinear:
75 | raise NotImplementedError
76 |
77 | def forward(self, x):
78 | conv1 = self.conv_block1(x)
79 | conv2 = self.conv_block2(conv1)
80 | conv3 = self.conv_block3(conv2)
81 | conv4 = self.conv_block4(conv3)
82 | conv5 = self.conv_block5(conv4)
83 |
84 | score = self.classifier(conv5)
85 |
86 | out = F.upsample(score, x.size()[2:])
87 |
88 | return out
89 |
90 | def init_vgg16_params(self, vgg16, copy_fc8=True):
91 | blocks = [
92 | self.conv_block1,
93 | self.conv_block2,
94 | self.conv_block3,
95 | self.conv_block4,
96 | self.conv_block5,
97 | ]
98 |
99 | ranges = [[0, 4], [5, 9], [10, 16], [17, 23], [24, 29]]
100 | features = list(vgg16.features.children())
101 |
102 | for idx, conv_block in enumerate(blocks):
103 | for l1, l2 in zip(features[ranges[idx][0] : ranges[idx][1]], conv_block):
104 | if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d):
105 | assert l1.weight.size() == l2.weight.size()
106 | assert l1.bias.size() == l2.bias.size()
107 | l2.weight.data = l1.weight.data
108 | l2.bias.data = l1.bias.data
109 | for i1, i2 in zip([0, 3], [0, 3]):
110 | l1 = vgg16.classifier[i1]
111 | l2 = self.classifier[i2]
112 | l2.weight.data = l1.weight.data.view(l2.weight.size())
113 | l2.bias.data = l1.bias.data.view(l2.bias.size())
114 | n_class = self.classifier[6].weight.size()[0]
115 | if copy_fc8:
116 | l1 = vgg16.classifier[6]
117 | l2 = self.classifier[6]
118 | l2.weight.data = l1.weight.data[:n_class, :].view(l2.weight.size())
119 | l2.bias.data = l1.bias.data[:n_class]
120 |
121 |
122 | class fcn16s(nn.Module):
123 | def __init__(self, n_classes=21, learned_billinear=False):
124 | super(fcn16s, self).__init__()
125 | self.learned_billinear = learned_billinear
126 | self.n_classes = n_classes
127 | self.loss = functools.partial(cross_entropy2d, size_average=False)
128 |
129 | self.conv_block1 = nn.Sequential(
130 | nn.Conv2d(3, 64, 3, padding=100),
131 | nn.ReLU(inplace=True),
132 | nn.Conv2d(64, 64, 3, padding=1),
133 | nn.ReLU(inplace=True),
134 | nn.MaxPool2d(2, stride=2, ceil_mode=True),
135 | )
136 |
137 | self.conv_block2 = nn.Sequential(
138 | nn.Conv2d(64, 128, 3, padding=1),
139 | nn.ReLU(inplace=True),
140 | nn.Conv2d(128, 128, 3, padding=1),
141 | nn.ReLU(inplace=True),
142 | nn.MaxPool2d(2, stride=2, ceil_mode=True),
143 | )
144 |
145 | self.conv_block3 = nn.Sequential(
146 | nn.Conv2d(128, 256, 3, padding=1),
147 | nn.ReLU(inplace=True),
148 | nn.Conv2d(256, 256, 3, padding=1),
149 | nn.ReLU(inplace=True),
150 | nn.Conv2d(256, 256, 3, padding=1),
151 | nn.ReLU(inplace=True),
152 | nn.MaxPool2d(2, stride=2, ceil_mode=True),
153 | )
154 |
155 | self.conv_block4 = nn.Sequential(
156 | nn.Conv2d(256, 512, 3, padding=1),
157 | nn.ReLU(inplace=True),
158 | nn.Conv2d(512, 512, 3, padding=1),
159 | nn.ReLU(inplace=True),
160 | nn.Conv2d(512, 512, 3, padding=1),
161 | nn.ReLU(inplace=True),
162 | nn.MaxPool2d(2, stride=2, ceil_mode=True),
163 | )
164 |
165 | self.conv_block5 = nn.Sequential(
166 | nn.Conv2d(512, 512, 3, padding=1),
167 | nn.ReLU(inplace=True),
168 | nn.Conv2d(512, 512, 3, padding=1),
169 | nn.ReLU(inplace=True),
170 | nn.Conv2d(512, 512, 3, padding=1),
171 | nn.ReLU(inplace=True),
172 | nn.MaxPool2d(2, stride=2, ceil_mode=True),
173 | )
174 |
175 | self.classifier = nn.Sequential(
176 | nn.Conv2d(512, 4096, 7),
177 | nn.ReLU(inplace=True),
178 | nn.Dropout2d(),
179 | nn.Conv2d(4096, 4096, 1),
180 | nn.ReLU(inplace=True),
181 | nn.Dropout2d(),
182 | nn.Conv2d(4096, self.n_classes, 1),
183 | )
184 |
185 | self.score_pool4 = nn.Conv2d(512, self.n_classes, 1)
186 |
187 | # TODO: Add support for learned upsampling
188 | if self.learned_billinear:
189 | raise NotImplementedError
190 |
191 | def forward(self, x):
192 | conv1 = self.conv_block1(x)
193 | conv2 = self.conv_block2(conv1)
194 | conv3 = self.conv_block3(conv2)
195 | conv4 = self.conv_block4(conv3)
196 | conv5 = self.conv_block5(conv4)
197 |
198 | score = self.classifier(conv5)
199 | score_pool4 = self.score_pool4(conv4)
200 |
201 | score = F.upsample(score, score_pool4.size()[2:])
202 | score += score_pool4
203 | out = F.upsample(score, x.size()[2:])
204 |
205 | return out
206 |
207 | def init_vgg16_params(self, vgg16, copy_fc8=True):
208 | blocks = [
209 | self.conv_block1,
210 | self.conv_block2,
211 | self.conv_block3,
212 | self.conv_block4,
213 | self.conv_block5,
214 | ]
215 |
216 | ranges = [[0, 4], [5, 9], [10, 16], [17, 23], [24, 29]]
217 | features = list(vgg16.features.children())
218 |
219 | for idx, conv_block in enumerate(blocks):
220 | for l1, l2 in zip(features[ranges[idx][0] : ranges[idx][1]], conv_block):
221 | if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d):
222 | # print(idx, l1, l2)
223 | assert l1.weight.size() == l2.weight.size()
224 | assert l1.bias.size() == l2.bias.size()
225 | l2.weight.data = l1.weight.data
226 | l2.bias.data = l1.bias.data
227 | for i1, i2 in zip([0, 3], [0, 3]):
228 | l1 = vgg16.classifier[i1]
229 | l2 = self.classifier[i2]
230 | l2.weight.data = l1.weight.data.view(l2.weight.size())
231 | l2.bias.data = l1.bias.data.view(l2.bias.size())
232 | n_class = self.classifier[6].weight.size()[0]
233 | if copy_fc8:
234 | l1 = vgg16.classifier[6]
235 | l2 = self.classifier[6]
236 | l2.weight.data = l1.weight.data[:n_class, :].view(l2.weight.size())
237 | l2.bias.data = l1.bias.data[:n_class]
238 |
239 |
240 | # FCN 8s
241 | class fcn8s(nn.Module):
242 | def __init__(self, n_classes=21, learned_billinear=True):
243 | super(fcn8s, self).__init__()
244 | self.learned_billinear = learned_billinear
245 | self.n_classes = n_classes
246 | self.loss = functools.partial(cross_entropy2d, size_average=False)
247 |
248 | self.conv_block1 = nn.Sequential(
249 | nn.Conv2d(3, 64, 3, padding=100),
250 | nn.ReLU(inplace=True),
251 | nn.Conv2d(64, 64, 3, padding=1),
252 | nn.ReLU(inplace=True),
253 | nn.MaxPool2d(2, stride=2, ceil_mode=True),
254 | )
255 |
256 | self.conv_block2 = nn.Sequential(
257 | nn.Conv2d(64, 128, 3, padding=1),
258 | nn.ReLU(inplace=True),
259 | nn.Conv2d(128, 128, 3, padding=1),
260 | nn.ReLU(inplace=True),
261 | nn.MaxPool2d(2, stride=2, ceil_mode=True),
262 | )
263 |
264 | self.conv_block3 = nn.Sequential(
265 | nn.Conv2d(128, 256, 3, padding=1),
266 | nn.ReLU(inplace=True),
267 | nn.Conv2d(256, 256, 3, padding=1),
268 | nn.ReLU(inplace=True),
269 | nn.Conv2d(256, 256, 3, padding=1),
270 | nn.ReLU(inplace=True),
271 | nn.MaxPool2d(2, stride=2, ceil_mode=True),
272 | )
273 |
274 | self.conv_block4 = nn.Sequential(
275 | nn.Conv2d(256, 512, 3, padding=1),
276 | nn.ReLU(inplace=True),
277 | nn.Conv2d(512, 512, 3, padding=1),
278 | nn.ReLU(inplace=True),
279 | nn.Conv2d(512, 512, 3, padding=1),
280 | nn.ReLU(inplace=True),
281 | nn.MaxPool2d(2, stride=2, ceil_mode=True),
282 | )
283 |
284 | self.conv_block5 = nn.Sequential(
285 | nn.Conv2d(512, 512, 3, padding=1),
286 | nn.ReLU(inplace=True),
287 | nn.Conv2d(512, 512, 3, padding=1),
288 | nn.ReLU(inplace=True),
289 | nn.Conv2d(512, 512, 3, padding=1),
290 | nn.ReLU(inplace=True),
291 | nn.MaxPool2d(2, stride=2, ceil_mode=True),
292 | )
293 |
294 | self.classifier = nn.Sequential(
295 | nn.Conv2d(512, 4096, 7),
296 | nn.ReLU(inplace=True),
297 | nn.Dropout2d(),
298 | nn.Conv2d(4096, 4096, 1),
299 | nn.ReLU(inplace=True),
300 | nn.Dropout2d(),
301 | nn.Conv2d(4096, self.n_classes, 1),
302 | )
303 |
304 | self.score_pool4 = nn.Conv2d(512, self.n_classes, 1)
305 | self.score_pool3 = nn.Conv2d(256, self.n_classes, 1)
306 |
307 | if self.learned_billinear:
308 | self.upscore2 = nn.ConvTranspose2d(
309 | self.n_classes, self.n_classes, 4, stride=2, bias=False
310 | )
311 | self.upscore4 = nn.ConvTranspose2d(
312 | self.n_classes, self.n_classes, 4, stride=2, bias=False
313 | )
314 | self.upscore8 = nn.ConvTranspose2d(
315 | self.n_classes, self.n_classes, 16, stride=8, bias=False
316 | )
317 |
318 | for m in self.modules():
319 | if isinstance(m, nn.ConvTranspose2d):
320 | m.weight.data.copy_(
321 | get_upsampling_weight(m.in_channels, m.out_channels, m.kernel_size[0])
322 | )
323 |
324 | def forward(self, x):
325 | conv1 = self.conv_block1(x)
326 | conv2 = self.conv_block2(conv1)
327 | conv3 = self.conv_block3(conv2)
328 | conv4 = self.conv_block4(conv3)
329 | conv5 = self.conv_block5(conv4)
330 |
331 | score = self.classifier(conv5)
332 |
333 | if self.learned_billinear:
334 | upscore2 = self.upscore2(score)
335 | score_pool4c = self.score_pool4(conv4)[
336 | :, :, 5 : 5 + upscore2.size()[2], 5 : 5 + upscore2.size()[3]
337 | ]
338 | upscore_pool4 = self.upscore4(upscore2 + score_pool4c)
339 |
340 | score_pool3c = self.score_pool3(conv3)[
341 | :, :, 9 : 9 + upscore_pool4.size()[2], 9 : 9 + upscore_pool4.size()[3]
342 | ]
343 |
344 | out = self.upscore8(score_pool3c + upscore_pool4)[
345 | :, :, 31 : 31 + x.size()[2], 31 : 31 + x.size()[3]
346 | ]
347 | return out.contiguous()
348 |
349 | else:
350 | score_pool4 = self.score_pool4(conv4)
351 | score_pool3 = self.score_pool3(conv3)
352 | score = F.upsample(score, score_pool4.size()[2:])
353 | score += score_pool4
354 | score = F.upsample(score, score_pool3.size()[2:])
355 | score += score_pool3
356 | out = F.upsample(score, x.size()[2:])
357 |
358 | return out
359 |
360 | def init_vgg16_params(self, vgg16, copy_fc8=True):
361 | blocks = [
362 | self.conv_block1,
363 | self.conv_block2,
364 | self.conv_block3,
365 | self.conv_block4,
366 | self.conv_block5,
367 | ]
368 |
369 | ranges = [[0, 4], [5, 9], [10, 16], [17, 23], [24, 29]]
370 | features = list(vgg16.features.children())
371 |
372 | for idx, conv_block in enumerate(blocks):
373 | for l1, l2 in zip(features[ranges[idx][0] : ranges[idx][1]], conv_block):
374 | if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d):
375 | assert l1.weight.size() == l2.weight.size()
376 | assert l1.bias.size() == l2.bias.size()
377 | l2.weight.data = l1.weight.data
378 | l2.bias.data = l1.bias.data
379 | for i1, i2 in zip([0, 3], [0, 3]):
380 | l1 = vgg16.classifier[i1]
381 | l2 = self.classifier[i2]
382 | l2.weight.data = l1.weight.data.view(l2.weight.size())
383 | l2.bias.data = l1.bias.data.view(l2.bias.size())
384 | n_class = self.classifier[6].weight.size()[0]
385 | if copy_fc8:
386 | l1 = vgg16.classifier[6]
387 | l2 = self.classifier[6]
388 | l2.weight.data = l1.weight.data[:n_class, :].view(l2.weight.size())
389 | l2.bias.data = l1.bias.data[:n_class]
390 |
--------------------------------------------------------------------------------
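A quick forward-pass shape check for the FCN variants above, with randomly initialised weights (no VGG16 copy is needed just to check shapes). The 224x224 input size is an illustrative choice; the large `padding=100` in the first block plus the final upsampling (or the learned deconvolutions and cropping in fcn8s) bring the output back to the input resolution.

import torch
from ptsemseg.models.fcn import fcn32s

model = fcn32s(n_classes=21)
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 21, 224, 224])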
/ptsemseg/models/frrn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from ptsemseg.models.utils import FRRU, RU, conv2DBatchNormRelu, conv2DGroupNormRelu
6 |
7 | frrn_specs_dic = {
8 | "A": {
9 | "encoder": [[3, 96, 2], [4, 192, 4], [2, 384, 8], [2, 384, 16]],
10 | "decoder": [[2, 192, 8], [2, 192, 4], [2, 48, 2]],
11 | },
12 | "B": {
13 | "encoder": [[3, 96, 2], [4, 192, 4], [2, 384, 8], [2, 384, 16], [2, 384, 32]],
14 | "decoder": [[2, 192, 16], [2, 192, 8], [2, 192, 4], [2, 48, 2]],
15 | },
16 | }
17 |
18 |
19 | class frrn(nn.Module):
20 | """
21 | Full Resolution Residual Networks for Semantic Segmentation
22 | URL: https://arxiv.org/abs/1611.08323
23 |
24 | References:
25 | 1) Original Author's code: https://github.com/TobyPDE/FRRN
26 | 2) TF implementation by @hiwonjoon: https://github.com/hiwonjoon/tf-frrn
27 | """
28 |
29 | def __init__(self, n_classes=21, model_type="B", group_norm=False, n_groups=16):
30 | super(frrn, self).__init__()
31 | self.n_classes = n_classes
32 | self.model_type = model_type
33 | self.group_norm = group_norm
34 | self.n_groups = n_groups
35 |
36 | if self.group_norm:
37 | self.conv1 = conv2DGroupNormRelu(3, 48, 5, 1, 2)
38 | else:
39 | self.conv1 = conv2DBatchNormRelu(3, 48, 5, 1, 2)
40 |
41 | self.up_residual_units = []
42 | self.down_residual_units = []
43 | for i in range(3):
44 | self.up_residual_units.append(
45 | RU(
46 | channels=48,
47 | kernel_size=3,
48 | strides=1,
49 | group_norm=self.group_norm,
50 | n_groups=self.n_groups,
51 | )
52 | )
53 | self.down_residual_units.append(
54 | RU(
55 | channels=48,
56 | kernel_size=3,
57 | strides=1,
58 | group_norm=self.group_norm,
59 | n_groups=self.n_groups,
60 | )
61 | )
62 |
63 | self.up_residual_units = nn.ModuleList(self.up_residual_units)
64 | self.down_residual_units = nn.ModuleList(self.down_residual_units)
65 |
66 | self.split_conv = nn.Conv2d(48, 32, kernel_size=1, padding=0, stride=1, bias=False)
67 |
68 | # each spec is as (n_blocks, channels, scale)
69 | self.encoder_frru_specs = frrn_specs_dic[self.model_type]["encoder"]
70 |
71 | self.decoder_frru_specs = frrn_specs_dic[self.model_type]["decoder"]
72 |
73 | # encoding
74 | prev_channels = 48
75 | self.encoding_frrus = {}
76 | for n_blocks, channels, scale in self.encoder_frru_specs:
77 | for block in range(n_blocks):
78 | key = "_".join(map(str, ["encoding_frru", n_blocks, channels, scale, block]))
79 | setattr(
80 | self,
81 | key,
82 | FRRU(
83 | prev_channels=prev_channels,
84 | out_channels=channels,
85 | scale=scale,
86 | group_norm=self.group_norm,
87 | n_groups=self.n_groups,
88 | ),
89 | )
90 | prev_channels = channels
91 |
92 | # decoding
93 | self.decoding_frrus = {}
94 | for n_blocks, channels, scale in self.decoder_frru_specs:
95 | # pass through decoding FRRUs
96 | for block in range(n_blocks):
97 | key = "_".join(map(str, ["decoding_frru", n_blocks, channels, scale, block]))
98 | setattr(
99 | self,
100 | key,
101 | FRRU(
102 | prev_channels=prev_channels,
103 | out_channels=channels,
104 | scale=scale,
105 | group_norm=self.group_norm,
106 | n_groups=self.n_groups,
107 | ),
108 | )
109 | prev_channels = channels
110 |
111 | self.merge_conv = nn.Conv2d(
112 | prev_channels + 32, 48, kernel_size=1, padding=0, stride=1, bias=False
113 | )
114 |
115 | self.classif_conv = nn.Conv2d(
116 | 48, self.n_classes, kernel_size=1, padding=0, stride=1, bias=True
117 | )
118 |
119 | def forward(self, x):
120 |
121 | # pass to initial conv
122 | x = self.conv1(x)
123 |
124 | # pass through residual units
125 | for i in range(3):
126 | x = self.up_residual_units[i](x)
127 |
128 | # divide stream
129 | y = x
130 | z = self.split_conv(x)
131 |
132 | prev_channels = 48
133 | # encoding
134 | for n_blocks, channels, scale in self.encoder_frru_specs:
135 | # maxpool bigger feature map
136 | y_pooled = F.max_pool2d(y, stride=2, kernel_size=2, padding=0)
137 | # pass through encoding FRRUs
138 | for block in range(n_blocks):
139 | key = "_".join(map(str, ["encoding_frru", n_blocks, channels, scale, block]))
140 | y, z = getattr(self, key)(y_pooled, z)
141 | prev_channels = channels
142 |
143 | # decoding
144 | for n_blocks, channels, scale in self.decoder_frru_specs:
145 | # bilinear upsample smaller feature map
146 | upsample_size = torch.Size([_s * 2 for _s in y.size()[-2:]])
147 | y_upsampled = F.upsample(y, size=upsample_size, mode="bilinear", align_corners=True)
148 | # pass through decoding FRRUs
149 | for block in range(n_blocks):
150 | key = "_".join(map(str, ["decoding_frru", n_blocks, channels, scale, block]))
151 | # print("Incoming FRRU Size: ", key, y_upsampled.shape, z.shape)
152 | y, z = getattr(self, key)(y_upsampled, z)
153 | # print("Outgoing FRRU Size: ", key, y.shape, z.shape)
154 | prev_channels = channels
155 |
156 | # merge streams
157 | x = torch.cat(
158 | [F.upsample(y, scale_factor=2, mode="bilinear", align_corners=True), z], dim=1
159 | )
160 | x = self.merge_conv(x)
161 |
162 | # pass through residual units
163 | for i in range(3):
164 | x = self.down_residual_units[i](x)
165 |
166 | # final 1x1 conv to get classification
167 | x = self.classif_conv(x)
168 |
169 | return x
170 |
--------------------------------------------------------------------------------
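A forward-pass shape sketch for the FRRN above (random weights; assumes `ptsemseg` is importable). The pooling stream is downsampled by up to 16x for model_type "A" (32x for "B"), so the toy input side of 64 is chosen to divide evenly; the full-resolution residual stream keeps the output at the input size.

import torch
from ptsemseg.models.frrn import frrn

model = frrn(n_classes=21, model_type="A")
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 64, 64))
print(out.shape)  # torch.Size([1, 21, 64, 64])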
/ptsemseg/models/icnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from torch.autograd import Variable
7 |
8 | from ptsemseg import caffe_pb2
9 | from ptsemseg.models.utils import (
10 | get_interp_size,
11 | cascadeFeatureFusion,
12 | conv2DBatchNormRelu,
13 | residualBlockPSP,
14 | pyramidPooling,
15 | )
16 | from ptsemseg.loss.loss import multi_scale_cross_entropy2d
17 |
18 | icnet_specs = {
19 | "cityscapes": {"n_classes": 19, "input_size": (1025, 2049), "block_config": [3, 4, 6, 3]}
20 | }
21 |
22 |
23 | class icnet(nn.Module):
24 |
25 | """
26 | Image Cascade Network
27 | URL: https://arxiv.org/abs/1704.08545
28 |
29 | References:
30 | 1) Original Author's code: https://github.com/hszhao/ICNet
31 | 2) Chainer implementation by @mitmul: https://github.com/mitmul/chainer-pspnet
32 | 3) TensorFlow implementation by @hellochick: https://github.com/hellochick/ICNet-tensorflow
33 |
34 | """
35 |
36 | def __init__(
37 | self,
38 | n_classes=19,
39 | block_config=[3, 4, 6, 3],
40 | input_size=(1025, 2049),
41 | version=None,
42 | is_batchnorm=True,
43 | ):
44 |
45 | super(icnet, self).__init__()
46 |
47 | bias = not is_batchnorm
48 |
49 | self.block_config = (
50 | icnet_specs[version]["block_config"] if version is not None else block_config
51 | )
52 | self.n_classes = icnet_specs[version]["n_classes"] if version is not None else n_classes
53 | self.input_size = icnet_specs[version]["input_size"] if version is not None else input_size
54 |
55 | # Encoder
56 | self.convbnrelu1_1 = conv2DBatchNormRelu(
57 | in_channels=3,
58 | k_size=3,
59 | n_filters=32,
60 | padding=1,
61 | stride=2,
62 | bias=bias,
63 | is_batchnorm=is_batchnorm,
64 | )
65 | self.convbnrelu1_2 = conv2DBatchNormRelu(
66 | in_channels=32,
67 | k_size=3,
68 | n_filters=32,
69 | padding=1,
70 | stride=1,
71 | bias=bias,
72 | is_batchnorm=is_batchnorm,
73 | )
74 | self.convbnrelu1_3 = conv2DBatchNormRelu(
75 | in_channels=32,
76 | k_size=3,
77 | n_filters=64,
78 | padding=1,
79 | stride=1,
80 | bias=bias,
81 | is_batchnorm=is_batchnorm,
82 | )
83 |
84 | # Vanilla Residual Blocks
85 | self.res_block2 = residualBlockPSP(
86 | self.block_config[0], 64, 32, 128, 1, 1, is_batchnorm=is_batchnorm
87 | )
88 | self.res_block3_conv = residualBlockPSP(
89 | self.block_config[1],
90 | 128,
91 | 64,
92 | 256,
93 | 2,
94 | 1,
95 | include_range="conv",
96 | is_batchnorm=is_batchnorm,
97 | )
98 | self.res_block3_identity = residualBlockPSP(
99 | self.block_config[1],
100 | 128,
101 | 64,
102 | 256,
103 | 2,
104 | 1,
105 | include_range="identity",
106 | is_batchnorm=is_batchnorm,
107 | )
108 |
109 | # Dilated Residual Blocks
110 | self.res_block4 = residualBlockPSP(
111 | self.block_config[2], 256, 128, 512, 1, 2, is_batchnorm=is_batchnorm
112 | )
113 | self.res_block5 = residualBlockPSP(
114 | self.block_config[3], 512, 256, 1024, 1, 4, is_batchnorm=is_batchnorm
115 | )
116 |
117 | # Pyramid Pooling Module
118 | self.pyramid_pooling = pyramidPooling(
119 | 1024, [6, 3, 2, 1], model_name="icnet", fusion_mode="sum", is_batchnorm=is_batchnorm
120 | )
121 |
122 | # Final conv layer with kernel 1 in sub4 branch
123 | self.conv5_4_k1 = conv2DBatchNormRelu(
124 | in_channels=1024,
125 | k_size=1,
126 | n_filters=256,
127 | padding=0,
128 | stride=1,
129 | bias=bias,
130 | is_batchnorm=is_batchnorm,
131 | )
132 |
133 | # High-resolution (sub1) branch
134 | self.convbnrelu1_sub1 = conv2DBatchNormRelu(
135 | in_channels=3,
136 | k_size=3,
137 | n_filters=32,
138 | padding=1,
139 | stride=2,
140 | bias=bias,
141 | is_batchnorm=is_batchnorm,
142 | )
143 | self.convbnrelu2_sub1 = conv2DBatchNormRelu(
144 | in_channels=32,
145 | k_size=3,
146 | n_filters=32,
147 | padding=1,
148 | stride=2,
149 | bias=bias,
150 | is_batchnorm=is_batchnorm,
151 | )
152 | self.convbnrelu3_sub1 = conv2DBatchNormRelu(
153 | in_channels=32,
154 | k_size=3,
155 | n_filters=64,
156 | padding=1,
157 | stride=2,
158 | bias=bias,
159 | is_batchnorm=is_batchnorm,
160 | )
161 | self.classification = nn.Conv2d(128, self.n_classes, 1, 1, 0)
162 |
163 | # Cascade Feature Fusion Units
164 | self.cff_sub24 = cascadeFeatureFusion(
165 | self.n_classes, 256, 256, 128, is_batchnorm=is_batchnorm
166 | )
167 | self.cff_sub12 = cascadeFeatureFusion(
168 | self.n_classes, 128, 64, 128, is_batchnorm=is_batchnorm
169 | )
170 |
171 | # Define auxiliary loss function
172 | self.loss = multi_scale_cross_entropy2d
173 |
174 | def forward(self, x):
175 | h, w = x.shape[2:]
176 |
177 | # H, W -> H/2, W/2
178 | x_sub2 = F.interpolate(
179 | x, size=get_interp_size(x, s_factor=2), mode="bilinear", align_corners=True
180 | )
181 |
182 | # H/2, W/2 -> H/4, W/4
183 | x_sub2 = self.convbnrelu1_1(x_sub2)
184 | x_sub2 = self.convbnrelu1_2(x_sub2)
185 | x_sub2 = self.convbnrelu1_3(x_sub2)
186 |
187 | # H/4, W/4 -> H/8, W/8
188 | x_sub2 = F.max_pool2d(x_sub2, 3, 2, 1)
189 |
190 | # H/8, W/8 -> H/16, W/16
191 | x_sub2 = self.res_block2(x_sub2)
192 | x_sub2 = self.res_block3_conv(x_sub2)
193 | # H/16, W/16 -> H/32, W/32
194 | x_sub4 = F.interpolate(
195 | x_sub2, size=get_interp_size(x_sub2, s_factor=2), mode="bilinear", align_corners=True
196 | )
197 | x_sub4 = self.res_block3_identity(x_sub4)
198 |
199 | x_sub4 = self.res_block4(x_sub4)
200 | x_sub4 = self.res_block5(x_sub4)
201 |
202 | x_sub4 = self.pyramid_pooling(x_sub4)
203 | x_sub4 = self.conv5_4_k1(x_sub4)
204 |
205 | x_sub1 = self.convbnrelu1_sub1(x)
206 | x_sub1 = self.convbnrelu2_sub1(x_sub1)
207 | x_sub1 = self.convbnrelu3_sub1(x_sub1)
208 |
209 | x_sub24, sub4_cls = self.cff_sub24(x_sub4, x_sub2)
210 | x_sub12, sub24_cls = self.cff_sub12(x_sub24, x_sub1)
211 |
212 | x_sub12 = F.interpolate(
213 | x_sub12, size=get_interp_size(x_sub12, z_factor=2), mode="bilinear", align_corners=True
214 | )
215 | x_sub4 = self.res_block3_identity(x_sub4)
216 | sub124_cls = self.classification(x_sub12)
217 |
218 | if self.training:
219 | return (sub124_cls, sub24_cls, sub4_cls)
220 | else:
221 | sub124_cls = F.interpolate(
222 | sub124_cls,
223 | size=get_interp_size(sub124_cls, z_factor=4),
224 | mode="bilinear",
225 | align_corners=True,
226 | )
227 | return sub124_cls
228 |
229 | def load_pretrained_model(self, model_path):
230 | """
231 | Load weights from caffemodel w/o caffe dependency
232 | and plug them in corresponding modules
233 | """
234 | # My eyes and my heart both hurt when writing this method
235 |
236 | # Only care about layer_types that have trainable parameters
237 | ltypes = [
238 | "BNData",
239 | "ConvolutionData",
240 | "HoleConvolutionData",
241 | "Convolution",
242 | ] # Convolution type for conv3_sub1_proj
243 |
244 | def _get_layer_params(layer, ltype):
245 |
246 | if ltype == "BNData":
247 | gamma = np.array(layer.blobs[0].data)
248 | beta = np.array(layer.blobs[1].data)
249 | mean = np.array(layer.blobs[2].data)
250 | var = np.array(layer.blobs[3].data)
251 | return [mean, var, gamma, beta]
252 |
253 | elif ltype in ["ConvolutionData", "HoleConvolutionData", "Convolution"]:
254 | is_bias = layer.convolution_param.bias_term
255 | weights = np.array(layer.blobs[0].data)
256 | bias = []
257 | if is_bias:
258 | bias = np.array(layer.blobs[1].data)
259 | return [weights, bias]
260 |
261 | elif ltype == "InnerProduct":
262 | raise Exception("Fully connected layers {}, not supported".format(ltype))
263 |
264 | else:
265 | raise Exception("Unkown layer type {}".format(ltype))
266 |
267 | net = caffe_pb2.NetParameter()
268 | with open(model_path, "rb") as model_file:
269 | net.MergeFromString(model_file.read())
270 |
271 | # dict formatted as -> key: layer_name :: value: layer_type
272 | layer_types = {}
273 | # dict formatted as -> key: layer_name :: value: [layer_params]
274 | layer_params = {}
275 |
276 | for l in net.layer:
277 | lname = l.name
278 | ltype = l.type
279 | lbottom = l.bottom
280 | ltop = l.top
281 | if ltype in ltypes:
282 | print("Processing layer {} | {}, {}".format(lname, lbottom, ltop))
283 | layer_types[lname] = ltype
284 | layer_params[lname] = _get_layer_params(l, ltype)
285 | # if len(l.blobs) > 0:
286 | # print(lname, ltype, lbottom, ltop, len(l.blobs))
287 |
288 | # Set affine=False for all batchnorm modules
289 | def _no_affine_bn(module=None):
290 | if isinstance(module, nn.BatchNorm2d):
291 | module.affine = False
292 |
293 | if len([m for m in module.children()]) > 0:
294 | for child in module.children():
295 | _no_affine_bn(child)
296 |
297 | # _no_affine_bn(self)
298 |
299 | def _transfer_conv(layer_name, module):
300 | weights, bias = layer_params[layer_name]
301 | w_shape = np.array(module.weight.size())
302 |
303 | print(
304 | "CONV {}: Original {} and trans weights {}".format(
305 | layer_name, w_shape, weights.shape
306 | )
307 | )
308 |
309 | module.weight.data.copy_(torch.from_numpy(weights).view_as(module.weight))
310 |
311 | if len(bias) != 0:
312 | b_shape = np.array(module.bias.size())
313 | print(
314 | "CONV {}: Original {} and trans bias {}".format(layer_name, b_shape, bias.shape)
315 | )
316 | module.bias.data.copy_(torch.from_numpy(bias).view_as(module.bias))
317 |
318 | def _transfer_bn(conv_layer_name, bn_module):
319 | mean, var, gamma, beta = layer_params[conv_layer_name + "/bn"]
320 | print(
321 | "BN {}: Original {} and trans weights {}".format(
322 | conv_layer_name, bn_module.running_mean.size(), mean.shape
323 | )
324 | )
325 | bn_module.running_mean.copy_(torch.from_numpy(mean).view_as(bn_module.running_mean))
326 | bn_module.running_var.copy_(torch.from_numpy(var).view_as(bn_module.running_var))
327 | bn_module.weight.data.copy_(torch.from_numpy(gamma).view_as(bn_module.weight))
328 | bn_module.bias.data.copy_(torch.from_numpy(beta).view_as(bn_module.bias))
329 |
330 | def _transfer_conv_bn(conv_layer_name, mother_module):
331 | conv_module = mother_module[0]
332 | _transfer_conv(conv_layer_name, conv_module)
333 |
334 | if conv_layer_name + "/bn" in layer_params.keys():
335 | bn_module = mother_module[1]
336 | _transfer_bn(conv_layer_name, bn_module)
337 |
338 | def _transfer_residual(block_name, block):
339 | block_module, n_layers = block[0], block[1]
340 | prefix = block_name[:5]
341 |
342 | if ("bottleneck" in block_name) or ("identity" not in block_name): # Conv block
343 | bottleneck = block_module.layers[0]
344 | bottleneck_conv_bn_dic = {
345 | prefix + "_1_1x1_reduce": bottleneck.cbr1.cbr_unit,
346 | prefix + "_1_3x3": bottleneck.cbr2.cbr_unit,
347 | prefix + "_1_1x1_proj": bottleneck.cb4.cb_unit,
348 | prefix + "_1_1x1_increase": bottleneck.cb3.cb_unit,
349 | }
350 |
351 | for k, v in bottleneck_conv_bn_dic.items():
352 | _transfer_conv_bn(k, v)
353 |
354 | if ("identity" in block_name) or ("bottleneck" not in block_name): # Identity blocks
355 | base_idx = 2 if "identity" in block_name else 1
356 |
357 | for layer_idx in range(2, n_layers + 1):
358 | residual_layer = block_module.layers[layer_idx - base_idx]
359 | residual_conv_bn_dic = {
360 | "_".join(
361 | map(str, [prefix, layer_idx, "1x1_reduce"])
362 | ): residual_layer.cbr1.cbr_unit,
363 | "_".join(
364 | map(str, [prefix, layer_idx, "3x3"])
365 | ): residual_layer.cbr2.cbr_unit,
366 | "_".join(
367 | map(str, [prefix, layer_idx, "1x1_increase"])
368 | ): residual_layer.cb3.cb_unit,
369 | }
370 |
371 | for k, v in residual_conv_bn_dic.items():
372 | _transfer_conv_bn(k, v)
373 |
374 | convbn_layer_mapping = {
375 | "conv1_1_3x3_s2": self.convbnrelu1_1.cbr_unit,
376 | "conv1_2_3x3": self.convbnrelu1_2.cbr_unit,
377 | "conv1_3_3x3": self.convbnrelu1_3.cbr_unit,
378 | "conv1_sub1": self.convbnrelu1_sub1.cbr_unit,
379 | "conv2_sub1": self.convbnrelu2_sub1.cbr_unit,
380 | "conv3_sub1": self.convbnrelu3_sub1.cbr_unit,
381 | # 'conv5_3_pool6_conv': self.pyramid_pooling.paths[0].cbr_unit,
382 | # 'conv5_3_pool3_conv': self.pyramid_pooling.paths[1].cbr_unit,
383 | # 'conv5_3_pool2_conv': self.pyramid_pooling.paths[2].cbr_unit,
384 | # 'conv5_3_pool1_conv': self.pyramid_pooling.paths[3].cbr_unit,
385 | "conv5_4_k1": self.conv5_4_k1.cbr_unit,
386 | "conv_sub4": self.cff_sub24.low_dilated_conv_bn.cb_unit,
387 | "conv3_1_sub2_proj": self.cff_sub24.high_proj_conv_bn.cb_unit,
388 | "conv_sub2": self.cff_sub12.low_dilated_conv_bn.cb_unit,
389 | "conv3_sub1_proj": self.cff_sub12.high_proj_conv_bn.cb_unit,
390 | }
391 |
392 | residual_layers = {
393 | "conv2": [self.res_block2, self.block_config[0]],
394 | "conv3_bottleneck": [self.res_block3_conv, self.block_config[1]],
395 | "conv3_identity": [self.res_block3_identity, self.block_config[1]],
396 | "conv4": [self.res_block4, self.block_config[2]],
397 | "conv5": [self.res_block5, self.block_config[3]],
398 | }
399 |
400 | # Transfer weights for all non-residual conv+bn layers
401 | for k, v in convbn_layer_mapping.items():
402 | _transfer_conv_bn(k, v)
403 |
404 | # Transfer weights for final non-bn conv layer
405 | _transfer_conv("conv6_cls", self.classification)
406 | _transfer_conv("conv6_sub4", self.cff_sub24.low_classifier_conv)
407 | _transfer_conv("conv6_sub2", self.cff_sub12.low_classifier_conv)
408 |
409 | # Transfer weights for all residual layers
410 | for k, v in residual_layers.items():
411 | _transfer_residual(k, v)
412 |
413 | def tile_predict(self, imgs, include_flip_mode=True):
414 | """
415 | Predict by taking overlapping tiles from the image.
416 |
417 | Tile strides are computed adaptively from the image shape
418 | and the model input size (self.input_size).
419 |
420 | :param imgs: torch.Tensor with shape [N, C, H, W] in BGR format
421 | :param include_flip_mode: bool, if True predictions are also averaged
422 | over horizontally flipped tiles.
423 | """
424 |
425 | side_x, side_y = self.input_size
426 | n_classes = self.n_classes
427 | n_samples, c, h, w = imgs.shape
428 | # n = int(max(h,w) / float(side) + 1)
429 | n_x = int(h / float(side_x) + 1)
430 | n_y = int(w / float(side_y) + 1)
431 | stride_x = (h - side_x) / float(n_x)
432 | stride_y = (w - side_y) / float(n_y)
433 |
434 | x_ends = [[int(i * stride_x), int(i * stride_x) + side_x] for i in range(n_x + 1)]
435 | y_ends = [[int(i * stride_y), int(i * stride_y) + side_y] for i in range(n_y + 1)]
436 |
437 | pred = np.zeros([n_samples, n_classes, h, w])
438 | count = np.zeros([h, w])
439 |
440 | slice_count = 0
441 | for sx, ex in x_ends:
442 | for sy, ey in y_ends:
443 | slice_count += 1
444 |
445 | imgs_slice = imgs[:, :, sx:ex, sy:ey]
446 | if include_flip_mode:
447 | imgs_slice_flip = torch.from_numpy(
448 | np.copy(imgs_slice.cpu().numpy()[:, :, :, ::-1])
449 | ).float()
450 |
451 | is_model_on_cuda = next(self.parameters()).is_cuda
452 |
453 | inp = Variable(imgs_slice, volatile=True)
454 | if include_flip_mode:
455 | flp = Variable(imgs_slice_flip, volatile=True)
456 |
457 | if is_model_on_cuda:
458 | inp = inp.cuda()
459 | if include_flip_mode:
460 | flp = flp.cuda()
461 |
462 | psub1 = F.softmax(self.forward(inp), dim=1).data.cpu().numpy()
463 | if include_flip_mode:
464 | psub2 = F.softmax(self.forward(flp), dim=1).data.cpu().numpy()
465 | psub = (psub1 + psub2[:, :, :, ::-1]) / 2.0
466 | else:
467 | psub = psub1
468 |
469 | pred[:, :, sx:ex, sy:ey] = psub
470 | count[sx:ex, sy:ey] += 1.0
471 |
472 | score = (pred / count[None, None, ...]).astype(np.float32)
473 | return score / np.expand_dims(score.sum(axis=1), axis=1)
474 |
475 |
476 | # For Testing Purposes only
477 | if __name__ == "__main__":
478 | cd = 0
479 | import os
480 | import scipy.misc as m
481 | from ptsemseg.loader.cityscapes_loader import cityscapesLoader as cl
482 |
483 | ic = icnet(version="cityscapes", is_batchnorm=False)
484 |
485 | # Just need to do this one time
486 | caffemodel_dir_path = "PATH_TO_ICNET_DIR/evaluation/model"
487 | ic.load_pretrained_model(
488 | model_path=os.path.join(caffemodel_dir_path, "icnet_cityscapes_train_30k.caffemodel")
489 | )
490 | # ic.load_pretrained_model(model_path=os.path.join(caffemodel_dir_path,
491 | # 'icnet_cityscapes_train_30k_bnnomerge.caffemodel'))
492 | # ic.load_pretrained_model(model_path=os.path.join(caffemodel_dir_path,
493 | # 'icnet_cityscapes_trainval_90k.caffemodel'))
494 | # ic.load_pretrained_model(model_path=os.path.join(caffemodel_dir_path,
495 | # 'icnet_cityscapes_trainval_90k_bnnomerge.caffemodel'))
496 |
497 | # ic.load_state_dict(torch.load('ic.pth'))
498 |
499 | ic.float()
500 | ic.cuda(cd)
501 | ic.eval()
502 |
503 | dataset_root_dir = "PATH_TO_CITYSCAPES_DIR"
504 | dst = cl(root=dataset_root_dir)
505 | img = m.imread(
506 | os.path.join(
507 | dataset_root_dir,
508 | "leftImg8bit/demoVideo/stuttgart_00/stuttgart_00_000000_000010_leftImg8bit.png",
509 | )
510 | )
511 | m.imsave("test_input.png", img)
512 | orig_size = img.shape[:-1]
513 | img = m.imresize(img, ic.input_size) # uint8 with RGB mode
514 | img = img.transpose(2, 0, 1)
515 | img = img.astype(np.float64)
516 | img -= np.array([123.68, 116.779, 103.939])[:, None, None]
517 | img = np.copy(img[::-1, :, :])
518 | img = torch.from_numpy(img).float()
519 | img = img.unsqueeze(0)
520 |
521 | out = ic.tile_predict(img)
522 | pred = np.argmax(out, axis=1)[0]
523 | pred = pred.astype(np.float32)
524 | pred = m.imresize(pred, orig_size, "nearest", mode="F") # float32 with F mode
525 | decoded = dst.decode_segmap(pred)
526 | m.imsave("test_output.png", decoded)
527 | # m.imsave('test_output.png', pred)
528 |
529 | checkpoints_dir_path = "checkpoints"
530 | if not os.path.exists(checkpoints_dir_path):
531 | os.mkdir(checkpoints_dir_path)
532 | ic = torch.nn.DataParallel(ic, device_ids=range(torch.cuda.device_count()))
533 | state = {"model_state": ic.state_dict()}
534 | torch.save(state, os.path.join(checkpoints_dir_path, "icnet_cityscapes_train_30k.pth"))
535 | # torch.save(state, os.path.join(checkpoints_dir_path, "icnetBN_cityscapes_train_30k.pth"))
536 | # torch.save(state, os.path.join(checkpoints_dir_path, "icnet_cityscapes_trainval_90k.pth"))
537 | # torch.save(state, os.path.join(checkpoints_dir_path, "icnetBN_cityscapes_trainval_90k.pth"))
538 | print("Output Shape {} \t Input Shape {}".format(out.shape, img.shape))
539 |
--------------------------------------------------------------------------------
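The tiling arithmetic inside `tile_predict` above can be checked standalone. The 1024x2048 frame matches Cityscapes; the 512 crop side is an illustrative assumption (the method itself uses `self.input_size`), chosen here just to show how the overlapping windows are laid out.

h, w, side = 1024, 2048, 512
n_x, n_y = int(h / float(side) + 1), int(w / float(side) + 1)
stride_x, stride_y = (h - side) / float(n_x), (w - side) / float(n_y)
x_ends = [[int(i * stride_x), int(i * stride_x) + side] for i in range(n_x + 1)]
y_ends = [[int(i * stride_y), int(i * stride_y) + side] for i in range(n_y + 1)]
print(x_ends)                   # [[0, 512], [170, 682], [341, 853], [512, 1024]]
print(len(y_ends), y_ends[-1])  # 6 [1536, 2048] -- every pixel is covered by at least one tile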
/ptsemseg/models/linknet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from ptsemseg.models.utils import conv2DBatchNormRelu, linknetUp, residualBlock
4 |
5 |
6 | class linknet(nn.Module):
7 | def __init__(
8 | self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3, is_batchnorm=True
9 | ):
10 | super(linknet, self).__init__()
11 | self.is_deconv = is_deconv
12 | self.in_channels = in_channels
13 | self.is_batchnorm = is_batchnorm
14 | self.feature_scale = feature_scale
15 | self.layers = [2, 2, 2, 2] # Currently hardcoded for ResNet-18
16 |
17 | filters = [64, 128, 256, 512]
18 | filters = [int(x / self.feature_scale) for x in filters]  # channel counts must stay integers
19 |
20 | self.inplanes = filters[0]
21 |
22 | # Encoder
23 | self.convbnrelu1 = conv2DBatchNormRelu(
24 | in_channels=3, k_size=7, n_filters=64, padding=3, stride=2, bias=False
25 | )
26 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
27 |
28 | block = residualBlock
29 | self.encoder1 = self._make_layer(block, filters[0], self.layers[0])
30 | self.encoder2 = self._make_layer(block, filters[1], self.layers[1], stride=2)
31 | self.encoder3 = self._make_layer(block, filters[2], self.layers[2], stride=2)
32 | self.encoder4 = self._make_layer(block, filters[3], self.layers[3], stride=2)
33 | self.avgpool = nn.AvgPool2d(7)
34 |
35 | # Decoder
36 | self.decoder4 = linknetUp(filters[3], filters[2])
37 | self.decoder3 = linknetUp(filters[2], filters[1])
38 | self.decoder2 = linknetUp(filters[1], filters[0])
39 | self.decoder1 = linknetUp(filters[0], filters[0])
40 |
41 | # Final Classifier
42 | self.finaldeconvbnrelu1 = nn.Sequential(
43 | nn.ConvTranspose2d(filters[0], 32 // feature_scale, 3, 2, 1),
44 | nn.BatchNorm2d(32 // feature_scale),
45 | nn.ReLU(inplace=True),
46 | )
47 | self.finalconvbnrelu2 = conv2DBatchNormRelu(
48 | in_channels=32 // feature_scale,
49 | k_size=3,
50 | n_filters=32 // feature_scale,
51 | padding=1,
52 | stride=1,
53 | )
54 | self.finalconv3 = nn.Conv2d(32 // feature_scale, n_classes, 2, 2, 0)
55 |
56 | def _make_layer(self, block, planes, blocks, stride=1):
57 | downsample = None
58 | if stride != 1 or self.inplanes != planes * block.expansion:
59 | downsample = nn.Sequential(
60 | nn.Conv2d(
61 | self.inplanes,
62 | planes * block.expansion,
63 | kernel_size=1,
64 | stride=stride,
65 | bias=False,
66 | ),
67 | nn.BatchNorm2d(planes * block.expansion),
68 | )
69 | layers = []
70 | layers.append(block(self.inplanes, planes, stride, downsample))
71 | self.inplanes = planes * block.expansion
72 | for i in range(1, blocks):
73 | layers.append(block(self.inplanes, planes))
74 | return nn.Sequential(*layers)
75 |
76 | def forward(self, x):
77 | # Encoder
78 | x = self.convbnrelu1(x)
79 | x = self.maxpool(x)
80 |
81 | e1 = self.encoder1(x)
82 | e2 = self.encoder2(e1)
83 | e3 = self.encoder3(e2)
84 | e4 = self.encoder4(e3)
85 |
86 | # Decoder with Skip Connections
87 | d4 = self.decoder4(e4)
88 | d4 += e3
89 | d3 = self.decoder3(d4)
90 | d3 += e2
91 | d2 = self.decoder2(d3)
92 | d2 += e1
93 | d1 = self.decoder1(d2)
94 |
95 | # Final Classification
96 | f1 = self.finaldeconvbnrelu1(d1)
97 | f2 = self.finalconvbnrelu2(f1)
98 | f3 = self.finalconv3(f2)
99 |
100 | return f3
101 |
--------------------------------------------------------------------------------
/ptsemseg/models/pspnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from torch.autograd import Variable
7 |
8 | from ptsemseg import caffe_pb2
9 | from ptsemseg.models.utils import conv2DBatchNormRelu, residualBlockPSP, pyramidPooling
10 | from ptsemseg.loss.loss import multi_scale_cross_entropy2d
11 |
12 | pspnet_specs = {
13 | "pascal": {"n_classes": 21, "input_size": (473, 473), "block_config": [3, 4, 23, 3]},
14 | "cityscapes": {"n_classes": 19, "input_size": (713, 713), "block_config": [3, 4, 23, 3]},
15 | "ade20k": {"n_classes": 150, "input_size": (473, 473), "block_config": [3, 4, 6, 3]},
16 | }
17 |
18 |
19 | class pspnet(nn.Module):
20 |
21 | """
22 | Pyramid Scene Parsing Network
23 | URL: https://arxiv.org/abs/1612.01105
24 |
25 | References:
26 | 1) Original Author's code: https://github.com/hszhao/PSPNet
27 | 2) Chainer implementation by @mitmul: https://github.com/mitmul/chainer-pspnet
28 | 3) TensorFlow implementation by @hellochick: https://github.com/hellochick/PSPNet-tensorflow
29 |
30 | Visualization:
31 | http://dgschwend.github.io/netscope/#/gist/6bfb59e6a3cfcb4e2bb8d47f827c2928
32 |
33 | """
34 |
35 | def __init__(
36 | self, n_classes=21, block_config=[3, 4, 23, 3], input_size=(473, 473), version=None
37 | ):
38 |
39 | super(pspnet, self).__init__()
40 |
41 | self.block_config = (
42 | pspnet_specs[version]["block_config"] if version is not None else block_config
43 | )
44 | self.n_classes = pspnet_specs[version]["n_classes"] if version is not None else n_classes
45 | self.input_size = pspnet_specs[version]["input_size"] if version is not None else input_size
46 |
47 | # Encoder
48 | self.convbnrelu1_1 = conv2DBatchNormRelu(
49 | in_channels=3, k_size=3, n_filters=64, padding=1, stride=2, bias=False
50 | )
51 | self.convbnrelu1_2 = conv2DBatchNormRelu(
52 | in_channels=64, k_size=3, n_filters=64, padding=1, stride=1, bias=False
53 | )
54 | self.convbnrelu1_3 = conv2DBatchNormRelu(
55 | in_channels=64, k_size=3, n_filters=128, padding=1, stride=1, bias=False
56 | )
57 |
58 | # Vanilla Residual Blocks
59 | self.res_block2 = residualBlockPSP(self.block_config[0], 128, 64, 256, 1, 1)
60 | self.res_block3 = residualBlockPSP(self.block_config[1], 256, 128, 512, 2, 1)
61 |
62 | # Dilated Residual Blocks
63 | self.res_block4 = residualBlockPSP(self.block_config[2], 512, 256, 1024, 1, 2)
64 | self.res_block5 = residualBlockPSP(self.block_config[3], 1024, 512, 2048, 1, 4)
65 |
66 | # Pyramid Pooling Module
67 | self.pyramid_pooling = pyramidPooling(2048, [6, 3, 2, 1])
68 |
69 | # Final conv layers
70 | self.cbr_final = conv2DBatchNormRelu(4096, 512, 3, 1, 1, False)
71 | self.dropout = nn.Dropout2d(p=0.1, inplace=False)
72 | self.classification = nn.Conv2d(512, self.n_classes, 1, 1, 0)
73 |
74 | # Auxiliary layers for training
75 | self.convbnrelu4_aux = conv2DBatchNormRelu(
76 | in_channels=1024, k_size=3, n_filters=256, padding=1, stride=1, bias=False
77 | )
78 | self.aux_cls = nn.Conv2d(256, self.n_classes, 1, 1, 0)
79 |
80 | # Define auxiliary loss function
81 | self.loss = multi_scale_cross_entropy2d
82 |
83 | def forward(self, x):
84 | inp_shape = x.shape[2:]
85 |
86 | # H, W -> H/2, W/2
87 | x = self.convbnrelu1_1(x)
88 | x = self.convbnrelu1_2(x)
89 | x = self.convbnrelu1_3(x)
90 |
91 | # H/2, W/2 -> H/4, W/4
92 | x = F.max_pool2d(x, 3, 2, 1)
93 |
94 | # H/4, W/4 -> H/8, W/8
95 | x = self.res_block2(x)
96 | x = self.res_block3(x)
97 | x = self.res_block4(x)
98 |
99 | # Auxiliary layers for training
100 | if self.training:
101 | x_aux = self.convbnrelu4_aux(x)
102 | x_aux = self.dropout(x_aux)
103 | x_aux = self.aux_cls(x_aux)
104 |
105 | x = self.res_block5(x)
106 |
107 | x = self.pyramid_pooling(x)
108 |
109 | x = self.cbr_final(x)
110 | x = self.dropout(x)
111 |
112 | x = self.classification(x)
113 | x = F.interpolate(x, size=inp_shape, mode="bilinear", align_corners=True)
114 |
115 | if self.training:
116 | return (x, x_aux)
117 | else: # eval mode
118 | return x
119 |
120 | def load_pretrained_model(self, model_path):
121 | """
122 | Load weights from caffemodel w/o caffe dependency
123 | and plug them in corresponding modules
124 | """
125 | # My eyes and my heart both hurt when writing this method
126 |
127 | # Only care about layer_types that have trainable parameters
128 | ltypes = ["BNData", "ConvolutionData", "HoleConvolutionData"]
129 |
130 | def _get_layer_params(layer, ltype):
131 |
132 | if ltype == "BNData":
133 | gamma = np.array(layer.blobs[0].data)
134 | beta = np.array(layer.blobs[1].data)
135 | mean = np.array(layer.blobs[2].data)
136 | var = np.array(layer.blobs[3].data)
137 | return [mean, var, gamma, beta]
138 |
139 | elif ltype in ["ConvolutionData", "HoleConvolutionData"]:
140 | is_bias = layer.convolution_param.bias_term
141 | weights = np.array(layer.blobs[0].data)
142 | bias = []
143 | if is_bias:
144 | bias = np.array(layer.blobs[1].data)
145 | return [weights, bias]
146 |
147 | elif ltype == "InnerProduct":
148 |                 raise Exception("Fully connected layers ({}) are not supported".format(ltype))
149 |
150 | else:
151 |                 raise Exception("Unknown layer type {}".format(ltype))
152 |
153 | net = caffe_pb2.NetParameter()
154 | with open(model_path, "rb") as model_file:
155 | net.MergeFromString(model_file.read())
156 |
157 |         # dict formatted as -> key: layer_name :: value: layer_type
158 |         layer_types = {}
159 |         # dict formatted as -> key: layer_name :: value: [layer_params]
160 |         layer_params = {}
161 |
162 | for l in net.layer:
163 | lname = l.name
164 | ltype = l.type
165 | if ltype in ltypes:
166 | print("Processing layer {}".format(lname))
167 | layer_types[lname] = ltype
168 | layer_params[lname] = _get_layer_params(l, ltype)
169 |
170 | # Set affine=False for all batchnorm modules
171 | def _no_affine_bn(module=None):
172 | if isinstance(module, nn.BatchNorm2d):
173 | module.affine = False
174 |
175 | if len([m for m in module.children()]) > 0:
176 | for child in module.children():
177 | _no_affine_bn(child)
178 |
179 | # _no_affine_bn(self)
180 |
181 | def _transfer_conv(layer_name, module):
182 | weights, bias = layer_params[layer_name]
183 | w_shape = np.array(module.weight.size())
184 |
185 | print(
186 | "CONV {}: Original {} and trans weights {}".format(
187 | layer_name, w_shape, weights.shape
188 | )
189 | )
190 |
191 | module.weight.data.copy_(torch.from_numpy(weights).view_as(module.weight))
192 |
193 | if len(bias) != 0:
194 | b_shape = np.array(module.bias.size())
195 | print(
196 | "CONV {}: Original {} and trans bias {}".format(layer_name, b_shape, bias.shape)
197 | )
198 | module.bias.data.copy_(torch.from_numpy(bias).view_as(module.bias))
199 |
200 | def _transfer_conv_bn(conv_layer_name, mother_module):
201 | conv_module = mother_module[0]
202 | bn_module = mother_module[1]
203 |
204 | _transfer_conv(conv_layer_name, conv_module)
205 |
206 | mean, var, gamma, beta = layer_params[conv_layer_name + "/bn"]
207 | print(
208 | "BN {}: Original {} and trans weights {}".format(
209 | conv_layer_name, bn_module.running_mean.size(), mean.shape
210 | )
211 | )
212 | bn_module.running_mean.copy_(torch.from_numpy(mean).view_as(bn_module.running_mean))
213 | bn_module.running_var.copy_(torch.from_numpy(var).view_as(bn_module.running_var))
214 | bn_module.weight.data.copy_(torch.from_numpy(gamma).view_as(bn_module.weight))
215 | bn_module.bias.data.copy_(torch.from_numpy(beta).view_as(bn_module.bias))
216 |
217 | def _transfer_residual(prefix, block):
218 | block_module, n_layers = block[0], block[1]
219 |
220 | bottleneck = block_module.layers[0]
221 | bottleneck_conv_bn_dic = {
222 | prefix + "_1_1x1_reduce": bottleneck.cbr1.cbr_unit,
223 | prefix + "_1_3x3": bottleneck.cbr2.cbr_unit,
224 | prefix + "_1_1x1_proj": bottleneck.cb4.cb_unit,
225 | prefix + "_1_1x1_increase": bottleneck.cb3.cb_unit,
226 | }
227 |
228 | for k, v in bottleneck_conv_bn_dic.items():
229 | _transfer_conv_bn(k, v)
230 |
231 | for layer_idx in range(2, n_layers + 1):
232 | residual_layer = block_module.layers[layer_idx - 1]
233 | residual_conv_bn_dic = {
234 | "_".join(
235 | map(str, [prefix, layer_idx, "1x1_reduce"])
236 | ): residual_layer.cbr1.cbr_unit,
237 | "_".join(map(str, [prefix, layer_idx, "3x3"])): residual_layer.cbr2.cbr_unit,
238 | "_".join(
239 | map(str, [prefix, layer_idx, "1x1_increase"])
240 | ): residual_layer.cb3.cb_unit,
241 | }
242 |
243 | for k, v in residual_conv_bn_dic.items():
244 | _transfer_conv_bn(k, v)
245 |
246 | convbn_layer_mapping = {
247 | "conv1_1_3x3_s2": self.convbnrelu1_1.cbr_unit,
248 | "conv1_2_3x3": self.convbnrelu1_2.cbr_unit,
249 | "conv1_3_3x3": self.convbnrelu1_3.cbr_unit,
250 | "conv5_3_pool6_conv": self.pyramid_pooling.paths[0].cbr_unit,
251 | "conv5_3_pool3_conv": self.pyramid_pooling.paths[1].cbr_unit,
252 | "conv5_3_pool2_conv": self.pyramid_pooling.paths[2].cbr_unit,
253 | "conv5_3_pool1_conv": self.pyramid_pooling.paths[3].cbr_unit,
254 | "conv5_4": self.cbr_final.cbr_unit,
255 | "conv4_" + str(self.block_config[2] + 1): self.convbnrelu4_aux.cbr_unit,
256 | } # Auxiliary layers for training
257 |
258 | residual_layers = {
259 | "conv2": [self.res_block2, self.block_config[0]],
260 | "conv3": [self.res_block3, self.block_config[1]],
261 | "conv4": [self.res_block4, self.block_config[2]],
262 | "conv5": [self.res_block5, self.block_config[3]],
263 | }
264 |
265 | # Transfer weights for all non-residual conv+bn layers
266 | for k, v in convbn_layer_mapping.items():
267 | _transfer_conv_bn(k, v)
268 |
269 | # Transfer weights for final non-bn conv layer
270 | _transfer_conv("conv6", self.classification)
271 | _transfer_conv("conv6_1", self.aux_cls)
272 |
273 | # Transfer weights for all residual layers
274 | for k, v in residual_layers.items():
275 | _transfer_residual(k, v)
276 |
277 | def tile_predict(self, imgs, include_flip_mode=True):
278 |         """
279 |         Predict by taking overlapping tiles from the image.
280 | 
281 |         Strides are adaptively computed from the imgs shape
282 |         and the model input size.
283 | 
284 |         :param imgs: torch.Tensor with shape [N, C, H, W] in BGR format
285 |         :param include_flip_mode: bool, if True predictions are averaged with
286 |             those on horizontally flipped tiles
287 |         """
288 |
289 | side_x, side_y = self.input_size
290 | n_classes = self.n_classes
291 | n_samples, c, h, w = imgs.shape
292 | # n = int(max(h,w) / float(side) + 1)
293 | n_x = int(h / float(side_x) + 1)
294 | n_y = int(w / float(side_y) + 1)
295 | stride_x = (h - side_x) / float(n_x)
296 | stride_y = (w - side_y) / float(n_y)
297 |
298 | x_ends = [[int(i * stride_x), int(i * stride_x) + side_x] for i in range(n_x + 1)]
299 | y_ends = [[int(i * stride_y), int(i * stride_y) + side_y] for i in range(n_y + 1)]
300 |
301 | pred = np.zeros([n_samples, n_classes, h, w])
302 | count = np.zeros([h, w])
303 |
304 | slice_count = 0
305 | for sx, ex in x_ends:
306 | for sy, ey in y_ends:
307 | slice_count += 1
308 |
309 | imgs_slice = imgs[:, :, sx:ex, sy:ey]
310 | if include_flip_mode:
311 | imgs_slice_flip = torch.from_numpy(
312 | np.copy(imgs_slice.cpu().numpy()[:, :, :, ::-1])
313 | ).float()
314 |
315 | is_model_on_cuda = next(self.parameters()).is_cuda
316 |
317 | inp = Variable(imgs_slice, volatile=True)
318 | if include_flip_mode:
319 | flp = Variable(imgs_slice_flip, volatile=True)
320 |
321 | if is_model_on_cuda:
322 | inp = inp.cuda()
323 | if include_flip_mode:
324 | flp = flp.cuda()
325 |
326 | psub1 = F.softmax(self.forward(inp), dim=1).data.cpu().numpy()
327 | if include_flip_mode:
328 | psub2 = F.softmax(self.forward(flp), dim=1).data.cpu().numpy()
329 | psub = (psub1 + psub2[:, :, :, ::-1]) / 2.0
330 | else:
331 | psub = psub1
332 |
333 | pred[:, :, sx:ex, sy:ey] = psub
334 | count[sx:ex, sy:ey] += 1.0
335 |
336 | score = (pred / count[None, None, ...]).astype(np.float32)
337 | return score / np.expand_dims(score.sum(axis=1), axis=1)
338 |
339 |
340 | # For Testing Purposes only
341 | if __name__ == "__main__":
342 | cd = 0
343 | import os
344 | import scipy.misc as m
345 | from ptsemseg.loader.cityscapes_loader import cityscapesLoader as cl
346 |
347 | psp = pspnet(version="cityscapes")
348 |
349 | # Just need to do this one time
350 | caffemodel_dir_path = "PATH_TO_PSPNET_DIR/evaluation/model"
351 | psp.load_pretrained_model(
352 | model_path=os.path.join(caffemodel_dir_path, "pspnet101_cityscapes.caffemodel")
353 | )
354 | # psp.load_pretrained_model(model_path=os.path.join(caffemodel_dir_path,
355 | # 'pspnet50_ADE20K.caffemodel'))
356 | # psp.load_pretrained_model(model_path=os.path.join(caffemodel_dir_path,
357 | # 'pspnet101_VOC2012.caffemodel'))
358 | #
359 | # psp.load_state_dict(torch.load('psp.pth'))
360 |
361 | psp.float()
362 | psp.cuda(cd)
363 | psp.eval()
364 |
365 | dataset_root_dir = "PATH_TO_CITYSCAPES_DIR"
366 | dst = cl(root=dataset_root_dir)
367 | img = m.imread(
368 | os.path.join(
369 | dataset_root_dir,
370 | "leftImg8bit/demoVideo/stuttgart_00/stuttgart_00_000000_000010_leftImg8bit.png",
371 | )
372 | )
373 | m.imsave("cropped.png", img)
374 | orig_size = img.shape[:-1]
375 | img = img.transpose(2, 0, 1)
376 | img = img.astype(np.float64)
377 | img -= np.array([123.68, 116.779, 103.939])[:, None, None]
378 | img = np.copy(img[::-1, :, :])
379 | img = torch.from_numpy(img).float() # convert to torch tensor
380 | img = img.unsqueeze(0)
381 |
382 | out = psp.tile_predict(img)
383 | pred = np.argmax(out, axis=1)[0]
384 | decoded = dst.decode_segmap(pred)
385 |     m.imsave("cityscapes_stuttgart_tiled.png", decoded)
386 |     # m.imsave('cityscapes_stuttgart_tiled.png', pred)
387 |
388 | checkpoints_dir_path = "checkpoints"
389 | if not os.path.exists(checkpoints_dir_path):
390 | os.mkdir(checkpoints_dir_path)
391 | psp = torch.nn.DataParallel(
392 | psp, device_ids=range(torch.cuda.device_count())
393 | ) # append `module.`
394 | state = {"model_state": psp.state_dict()}
395 | torch.save(state, os.path.join(checkpoints_dir_path, "pspnet_101_cityscapes.pth"))
396 | # torch.save(state, os.path.join(checkpoints_dir_path, "pspnet_50_ade20k.pth"))
397 | # torch.save(state, os.path.join(checkpoints_dir_path, "pspnet_101_pascalvoc.pth"))
398 | print("Output Shape {} \t Input Shape {}".format(out.shape, img.shape))
399 |
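A minimal sketch of exercising the model above end-to-end (no pretrained weights, eval mode so only the main branch is returned); the 240x240 input is an arbitrary size chosen so the strided stages and pooling windows divide cleanly, not the 713x713 crop the Cityscapes weights were trained on:

import torch
from ptsemseg.models.pspnet import pspnet

model = pspnet(version="cityscapes")   # 19 classes, block_config [3, 4, 23, 3]
model.eval()                           # eval mode: forward() returns only the main output
with torch.no_grad():
    out = model(torch.randn(1, 3, 240, 240))
print(out.shape)                       # torch.Size([1, 19, 240, 240])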
--------------------------------------------------------------------------------
/ptsemseg/models/refinenet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class refinenet(nn.Module):
5 | """
6 | RefineNet: Multi-Path Refinement Networks for High-Resolution Semantic Segmentation
7 | URL: https://arxiv.org/abs/1611.06612
8 |
9 | References:
10 | 1) Original Author's MATLAB code: https://github.com/guosheng/refinenet
11 | 2) TF implementation by @eragonruan: https://github.com/eragonruan/refinenet-image-segmentation
12 | """
13 |
14 | def __init__(self, n_classes=21):
15 | super(refinenet, self).__init__()
16 | self.n_classes = n_classes
17 |
18 | def forward(self, x):
19 | pass
20 |
--------------------------------------------------------------------------------
/ptsemseg/models/segnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from ptsemseg.models.utils import segnetDown2, segnetDown3, segnetUp2, segnetUp3
4 |
5 |
6 | class segnet(nn.Module):
7 | def __init__(self, n_classes=21, in_channels=3, is_unpooling=True):
8 | super(segnet, self).__init__()
9 |
10 | self.in_channels = in_channels
11 | self.is_unpooling = is_unpooling
12 |
13 | self.down1 = segnetDown2(self.in_channels, 64)
14 | self.down2 = segnetDown2(64, 128)
15 | self.down3 = segnetDown3(128, 256)
16 | self.down4 = segnetDown3(256, 512)
17 | self.down5 = segnetDown3(512, 512)
18 |
19 | self.up5 = segnetUp3(512, 512)
20 | self.up4 = segnetUp3(512, 256)
21 | self.up3 = segnetUp3(256, 128)
22 | self.up2 = segnetUp2(128, 64)
23 | self.up1 = segnetUp2(64, n_classes)
24 |
25 | def forward(self, inputs):
26 |
27 | down1, indices_1, unpool_shape1 = self.down1(inputs)
28 | down2, indices_2, unpool_shape2 = self.down2(down1)
29 | down3, indices_3, unpool_shape3 = self.down3(down2)
30 | down4, indices_4, unpool_shape4 = self.down4(down3)
31 | down5, indices_5, unpool_shape5 = self.down5(down4)
32 |
33 | up5 = self.up5(down5, indices_5, unpool_shape5)
34 | up4 = self.up4(up5, indices_4, unpool_shape4)
35 | up3 = self.up3(up4, indices_3, unpool_shape3)
36 | up2 = self.up2(up3, indices_2, unpool_shape2)
37 | up1 = self.up1(up2, indices_1, unpool_shape1)
38 |
39 | return up1
40 |
41 | def init_vgg16_params(self, vgg16):
42 | blocks = [self.down1, self.down2, self.down3, self.down4, self.down5]
43 |
44 | features = list(vgg16.features.children())
45 |
46 | vgg_layers = []
47 | for _layer in features:
48 | if isinstance(_layer, nn.Conv2d):
49 | vgg_layers.append(_layer)
50 |
51 | merged_layers = []
52 | for idx, conv_block in enumerate(blocks):
53 | if idx < 2:
54 | units = [conv_block.conv1.cbr_unit, conv_block.conv2.cbr_unit]
55 | else:
56 | units = [
57 | conv_block.conv1.cbr_unit,
58 | conv_block.conv2.cbr_unit,
59 | conv_block.conv3.cbr_unit,
60 | ]
61 | for _unit in units:
62 | for _layer in _unit:
63 | if isinstance(_layer, nn.Conv2d):
64 | merged_layers.append(_layer)
65 |
66 | assert len(vgg_layers) == len(merged_layers)
67 |
68 | for l1, l2 in zip(vgg_layers, merged_layers):
69 | if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d):
70 | assert l1.weight.size() == l2.weight.size()
71 | assert l1.bias.size() == l2.bias.size()
72 | l2.weight.data = l1.weight.data
73 | l2.bias.data = l1.bias.data
74 |
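A minimal usage sketch for init_vgg16_params above, assuming torchvision's VGG16 weights can be downloaded; the 13 convolution layers of VGG16 line up one-to-one with the SegNet encoder convolutions, which is what the assert checks:

import torch
from torchvision import models
from ptsemseg.models.segnet import segnet

model = segnet(n_classes=21, in_channels=3)
vgg16 = models.vgg16(pretrained=True)   # assumption: pretrained weights are available
model.init_vgg16_params(vgg16)          # copy the 13 VGG conv weights/biases into the encoder

model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))   # 224 is divisible by 32, so pool/unpool shapes line up
print(out.shape)                               # torch.Size([1, 21, 224, 224])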
--------------------------------------------------------------------------------
/ptsemseg/models/unet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from ptsemseg.models.utils import unetConv2, unetUp
4 |
5 |
6 | class unet(nn.Module):
7 | def __init__(
8 | self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3, is_batchnorm=True
9 | ):
10 | super(unet, self).__init__()
11 | self.is_deconv = is_deconv
12 | self.in_channels = in_channels
13 | self.is_batchnorm = is_batchnorm
14 | self.feature_scale = feature_scale
15 |
16 | filters = [64, 128, 256, 512, 1024]
17 | filters = [int(x / self.feature_scale) for x in filters]
18 |
19 | # downsampling
20 | self.conv1 = unetConv2(self.in_channels, filters[0], self.is_batchnorm)
21 | self.maxpool1 = nn.MaxPool2d(kernel_size=2)
22 |
23 | self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm)
24 | self.maxpool2 = nn.MaxPool2d(kernel_size=2)
25 |
26 | self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm)
27 | self.maxpool3 = nn.MaxPool2d(kernel_size=2)
28 |
29 | self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm)
30 | self.maxpool4 = nn.MaxPool2d(kernel_size=2)
31 |
32 | self.center = unetConv2(filters[3], filters[4], self.is_batchnorm)
33 |
34 | # upsampling
35 | self.up_concat4 = unetUp(filters[4], filters[3], self.is_deconv)
36 | self.up_concat3 = unetUp(filters[3], filters[2], self.is_deconv)
37 | self.up_concat2 = unetUp(filters[2], filters[1], self.is_deconv)
38 | self.up_concat1 = unetUp(filters[1], filters[0], self.is_deconv)
39 |
40 | # final conv (without any concat)
41 | self.final = nn.Conv2d(filters[0], n_classes, 1)
42 |
43 | def forward(self, inputs):
44 | conv1 = self.conv1(inputs)
45 | maxpool1 = self.maxpool1(conv1)
46 |
47 | conv2 = self.conv2(maxpool1)
48 | maxpool2 = self.maxpool2(conv2)
49 |
50 | conv3 = self.conv3(maxpool2)
51 | maxpool3 = self.maxpool3(conv3)
52 |
53 | conv4 = self.conv4(maxpool3)
54 | maxpool4 = self.maxpool4(conv4)
55 |
56 | center = self.center(maxpool4)
57 | up4 = self.up_concat4(conv4, center)
58 | up3 = self.up_concat3(conv3, up4)
59 | up2 = self.up_concat2(conv2, up3)
60 | up1 = self.up_concat1(conv1, up2)
61 |
62 | final = self.final(up1)
63 |
64 | return final
65 |
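Note that unetConv2 uses unpadded 3x3 convolutions, so the output map is smaller than the input (the original "valid convolution" U-Net behaviour; the skip connections are cropped via negative padding in unetUp). A small sketch with the canonical 572x572 input:

import torch
from ptsemseg.models.unet import unet

model = unet(feature_scale=4, n_classes=21, in_channels=3)
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 572, 572))
print(out.shape)   # torch.Size([1, 21, 388, 388]); unpadded convs shrink the map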
--------------------------------------------------------------------------------
/ptsemseg/models/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import torch.nn.functional as F
5 |
6 | from torch.autograd import Variable
7 |
8 |
9 | class conv2DBatchNorm(nn.Module):
10 | def __init__(
11 | self,
12 | in_channels,
13 | n_filters,
14 | k_size,
15 | stride,
16 | padding,
17 | bias=True,
18 | dilation=1,
19 | is_batchnorm=True,
20 | ):
21 | super(conv2DBatchNorm, self).__init__()
22 |
23 | conv_mod = nn.Conv2d(
24 | int(in_channels),
25 | int(n_filters),
26 | kernel_size=k_size,
27 | padding=padding,
28 | stride=stride,
29 | bias=bias,
30 | dilation=dilation,
31 | )
32 |
33 | if is_batchnorm:
34 | self.cb_unit = nn.Sequential(conv_mod, nn.BatchNorm2d(int(n_filters)))
35 | else:
36 | self.cb_unit = nn.Sequential(conv_mod)
37 |
38 | def forward(self, inputs):
39 | outputs = self.cb_unit(inputs)
40 | return outputs
41 |
42 |
43 | class conv2DGroupNorm(nn.Module):
44 | def __init__(
45 | self, in_channels, n_filters, k_size, stride, padding, bias=True, dilation=1, n_groups=16
46 | ):
47 | super(conv2DGroupNorm, self).__init__()
48 |
49 | conv_mod = nn.Conv2d(
50 | int(in_channels),
51 | int(n_filters),
52 | kernel_size=k_size,
53 | padding=padding,
54 | stride=stride,
55 | bias=bias,
56 | dilation=dilation,
57 | )
58 |
59 | self.cg_unit = nn.Sequential(conv_mod, nn.GroupNorm(n_groups, int(n_filters)))
60 |
61 | def forward(self, inputs):
62 | outputs = self.cg_unit(inputs)
63 | return outputs
64 |
65 |
66 | class deconv2DBatchNorm(nn.Module):
67 | def __init__(self, in_channels, n_filters, k_size, stride, padding, bias=True):
68 | super(deconv2DBatchNorm, self).__init__()
69 |
70 | self.dcb_unit = nn.Sequential(
71 | nn.ConvTranspose2d(
72 | int(in_channels),
73 | int(n_filters),
74 | kernel_size=k_size,
75 | padding=padding,
76 | stride=stride,
77 | bias=bias,
78 | ),
79 | nn.BatchNorm2d(int(n_filters)),
80 | )
81 |
82 | def forward(self, inputs):
83 | outputs = self.dcb_unit(inputs)
84 | return outputs
85 |
86 |
87 | class conv2DBatchNormRelu(nn.Module):
88 | def __init__(
89 | self,
90 | in_channels,
91 | n_filters,
92 | k_size,
93 | stride,
94 | padding,
95 | bias=True,
96 | dilation=1,
97 | is_batchnorm=True,
98 | ):
99 | super(conv2DBatchNormRelu, self).__init__()
100 |
101 | conv_mod = nn.Conv2d(
102 | int(in_channels),
103 | int(n_filters),
104 | kernel_size=k_size,
105 | padding=padding,
106 | stride=stride,
107 | bias=bias,
108 | dilation=dilation,
109 | )
110 |
111 | if is_batchnorm:
112 | self.cbr_unit = nn.Sequential(
113 | conv_mod, nn.BatchNorm2d(int(n_filters)), nn.ReLU(inplace=True)
114 | )
115 | else:
116 | self.cbr_unit = nn.Sequential(conv_mod, nn.ReLU(inplace=True))
117 |
118 | def forward(self, inputs):
119 | outputs = self.cbr_unit(inputs)
120 | return outputs
121 |
122 |
123 | class conv2DGroupNormRelu(nn.Module):
124 | def __init__(
125 | self, in_channels, n_filters, k_size, stride, padding, bias=True, dilation=1, n_groups=16
126 | ):
127 | super(conv2DGroupNormRelu, self).__init__()
128 |
129 | conv_mod = nn.Conv2d(
130 | int(in_channels),
131 | int(n_filters),
132 | kernel_size=k_size,
133 | padding=padding,
134 | stride=stride,
135 | bias=bias,
136 | dilation=dilation,
137 | )
138 |
139 | self.cgr_unit = nn.Sequential(
140 | conv_mod, nn.GroupNorm(n_groups, int(n_filters)), nn.ReLU(inplace=True)
141 | )
142 |
143 | def forward(self, inputs):
144 | outputs = self.cgr_unit(inputs)
145 | return outputs
146 |
147 |
148 | class deconv2DBatchNormRelu(nn.Module):
149 | def __init__(self, in_channels, n_filters, k_size, stride, padding, bias=True):
150 | super(deconv2DBatchNormRelu, self).__init__()
151 |
152 | self.dcbr_unit = nn.Sequential(
153 | nn.ConvTranspose2d(
154 | int(in_channels),
155 | int(n_filters),
156 | kernel_size=k_size,
157 | padding=padding,
158 | stride=stride,
159 | bias=bias,
160 | ),
161 | nn.BatchNorm2d(int(n_filters)),
162 | nn.ReLU(inplace=True),
163 | )
164 |
165 | def forward(self, inputs):
166 | outputs = self.dcbr_unit(inputs)
167 | return outputs
168 |
169 |
170 | class unetConv2(nn.Module):
171 | def __init__(self, in_size, out_size, is_batchnorm):
172 | super(unetConv2, self).__init__()
173 |
174 | if is_batchnorm:
175 | self.conv1 = nn.Sequential(
176 | nn.Conv2d(in_size, out_size, 3, 1, 0), nn.BatchNorm2d(out_size), nn.ReLU()
177 | )
178 | self.conv2 = nn.Sequential(
179 | nn.Conv2d(out_size, out_size, 3, 1, 0), nn.BatchNorm2d(out_size), nn.ReLU()
180 | )
181 | else:
182 | self.conv1 = nn.Sequential(nn.Conv2d(in_size, out_size, 3, 1, 0), nn.ReLU())
183 | self.conv2 = nn.Sequential(nn.Conv2d(out_size, out_size, 3, 1, 0), nn.ReLU())
184 |
185 | def forward(self, inputs):
186 | outputs = self.conv1(inputs)
187 | outputs = self.conv2(outputs)
188 | return outputs
189 |
190 |
191 | class unetUp(nn.Module):
192 | def __init__(self, in_size, out_size, is_deconv):
193 | super(unetUp, self).__init__()
194 | self.conv = unetConv2(in_size, out_size, False)
195 | if is_deconv:
196 | self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2)
197 | else:
198 | self.up = nn.UpsamplingBilinear2d(scale_factor=2)
199 |
200 | def forward(self, inputs1, inputs2):
201 | outputs2 = self.up(inputs2)
202 | offset = outputs2.size()[2] - inputs1.size()[2]
203 | padding = 2 * [offset // 2, offset // 2]
204 | outputs1 = F.pad(inputs1, padding)
205 | return self.conv(torch.cat([outputs1, outputs2], 1))
206 |
207 |
208 | class segnetDown2(nn.Module):
209 | def __init__(self, in_size, out_size):
210 | super(segnetDown2, self).__init__()
211 | self.conv1 = conv2DBatchNormRelu(in_size, out_size, 3, 1, 1)
212 | self.conv2 = conv2DBatchNormRelu(out_size, out_size, 3, 1, 1)
213 | self.maxpool_with_argmax = nn.MaxPool2d(2, 2, return_indices=True)
214 |
215 | def forward(self, inputs):
216 | outputs = self.conv1(inputs)
217 | outputs = self.conv2(outputs)
218 | unpooled_shape = outputs.size()
219 | outputs, indices = self.maxpool_with_argmax(outputs)
220 | return outputs, indices, unpooled_shape
221 |
222 |
223 | class segnetDown3(nn.Module):
224 | def __init__(self, in_size, out_size):
225 | super(segnetDown3, self).__init__()
226 | self.conv1 = conv2DBatchNormRelu(in_size, out_size, 3, 1, 1)
227 | self.conv2 = conv2DBatchNormRelu(out_size, out_size, 3, 1, 1)
228 | self.conv3 = conv2DBatchNormRelu(out_size, out_size, 3, 1, 1)
229 | self.maxpool_with_argmax = nn.MaxPool2d(2, 2, return_indices=True)
230 |
231 | def forward(self, inputs):
232 | outputs = self.conv1(inputs)
233 | outputs = self.conv2(outputs)
234 | outputs = self.conv3(outputs)
235 | unpooled_shape = outputs.size()
236 | outputs, indices = self.maxpool_with_argmax(outputs)
237 | return outputs, indices, unpooled_shape
238 |
239 |
240 | class segnetUp2(nn.Module):
241 | def __init__(self, in_size, out_size):
242 | super(segnetUp2, self).__init__()
243 | self.unpool = nn.MaxUnpool2d(2, 2)
244 | self.conv1 = conv2DBatchNormRelu(in_size, in_size, 3, 1, 1)
245 | self.conv2 = conv2DBatchNormRelu(in_size, out_size, 3, 1, 1)
246 |
247 | def forward(self, inputs, indices, output_shape):
248 | outputs = self.unpool(input=inputs, indices=indices, output_size=output_shape)
249 | outputs = self.conv1(outputs)
250 | outputs = self.conv2(outputs)
251 | return outputs
252 |
253 |
254 | class segnetUp3(nn.Module):
255 | def __init__(self, in_size, out_size):
256 | super(segnetUp3, self).__init__()
257 | self.unpool = nn.MaxUnpool2d(2, 2)
258 | self.conv1 = conv2DBatchNormRelu(in_size, in_size, 3, 1, 1)
259 | self.conv2 = conv2DBatchNormRelu(in_size, in_size, 3, 1, 1)
260 | self.conv3 = conv2DBatchNormRelu(in_size, out_size, 3, 1, 1)
261 |
262 | def forward(self, inputs, indices, output_shape):
263 | outputs = self.unpool(input=inputs, indices=indices, output_size=output_shape)
264 | outputs = self.conv1(outputs)
265 | outputs = self.conv2(outputs)
266 | outputs = self.conv3(outputs)
267 | return outputs
268 |
269 |
270 | class residualBlock(nn.Module):
271 | expansion = 1
272 |
273 | def __init__(self, in_channels, n_filters, stride=1, downsample=None):
274 | super(residualBlock, self).__init__()
275 |
276 | self.convbnrelu1 = conv2DBatchNormRelu(in_channels, n_filters, 3, stride, 1, bias=False)
277 | self.convbn2 = conv2DBatchNorm(n_filters, n_filters, 3, 1, 1, bias=False)
278 | self.downsample = downsample
279 | self.stride = stride
280 | self.relu = nn.ReLU(inplace=True)
281 |
282 | def forward(self, x):
283 | residual = x
284 |
285 | out = self.convbnrelu1(x)
286 | out = self.convbn2(out)
287 |
288 | if self.downsample is not None:
289 | residual = self.downsample(x)
290 |
291 | out += residual
292 | out = self.relu(out)
293 | return out
294 |
295 |
296 | class residualBottleneck(nn.Module):
297 | expansion = 4
298 |
299 | def __init__(self, in_channels, n_filters, stride=1, downsample=None):
300 | super(residualBottleneck, self).__init__()
301 |         self.convbn1 = conv2DBatchNorm(in_channels, n_filters, k_size=1, stride=1, padding=0, bias=False)
302 |         self.convbn2 = conv2DBatchNorm(
303 |             n_filters, n_filters, k_size=3, padding=1, stride=stride, bias=False
304 |         )
305 |         self.convbn3 = conv2DBatchNorm(n_filters, n_filters * 4, k_size=1, stride=1, padding=0, bias=False)
306 | self.relu = nn.ReLU(inplace=True)
307 | self.downsample = downsample
308 | self.stride = stride
309 |
310 | def forward(self, x):
311 | residual = x
312 |
313 | out = self.convbn1(x)
314 | out = self.convbn2(out)
315 | out = self.convbn3(out)
316 |
317 | if self.downsample is not None:
318 | residual = self.downsample(x)
319 |
320 | out += residual
321 | out = self.relu(out)
322 |
323 | return out
324 |
325 |
326 | class linknetUp(nn.Module):
327 | def __init__(self, in_channels, n_filters):
328 | super(linknetUp, self).__init__()
329 |
330 | # B, 2C, H, W -> B, C/2, H, W
331 |         self.convbnrelu1 = conv2DBatchNormRelu(
332 |             in_channels, n_filters / 2, k_size=1, stride=1, padding=0
333 |         )
334 | 
335 |         # B, C/2, H, W -> B, C/2, H, W
336 |         self.deconvbnrelu2 = deconv2DBatchNormRelu(
337 |             n_filters / 2, n_filters / 2, k_size=3, stride=2, padding=0
338 |         )
339 | 
340 |         # B, C/2, H, W -> B, C, H, W
341 |         self.convbnrelu3 = conv2DBatchNormRelu(
342 |             n_filters / 2, n_filters, k_size=1, stride=1, padding=0
343 |         )
344 |
345 | def forward(self, x):
346 | x = self.convbnrelu1(x)
347 | x = self.deconvbnrelu2(x)
348 | x = self.convbnrelu3(x)
349 | return x
350 |
351 |
352 | class FRRU(nn.Module):
353 | """
354 | Full Resolution Residual Unit for FRRN
355 | """
356 |
357 | def __init__(self, prev_channels, out_channels, scale, group_norm=False, n_groups=None):
358 | super(FRRU, self).__init__()
359 | self.scale = scale
360 | self.prev_channels = prev_channels
361 | self.out_channels = out_channels
362 | self.group_norm = group_norm
363 | self.n_groups = n_groups
364 |
365 | if self.group_norm:
366 | conv_unit = conv2DGroupNormRelu
367 | self.conv1 = conv_unit(
368 | prev_channels + 32,
369 | out_channels,
370 | k_size=3,
371 | stride=1,
372 | padding=1,
373 | bias=False,
374 | n_groups=self.n_groups,
375 | )
376 | self.conv2 = conv_unit(
377 | out_channels,
378 | out_channels,
379 | k_size=3,
380 | stride=1,
381 | padding=1,
382 | bias=False,
383 | n_groups=self.n_groups,
384 | )
385 |
386 | else:
387 | conv_unit = conv2DBatchNormRelu
388 | self.conv1 = conv_unit(
389 | prev_channels + 32, out_channels, k_size=3, stride=1, padding=1, bias=False
390 | )
391 | self.conv2 = conv_unit(
392 | out_channels, out_channels, k_size=3, stride=1, padding=1, bias=False
393 | )
394 |
395 | self.conv_res = nn.Conv2d(out_channels, 32, kernel_size=1, stride=1, padding=0)
396 |
397 | def forward(self, y, z):
398 | x = torch.cat([y, nn.MaxPool2d(self.scale, self.scale)(z)], dim=1)
399 | y_prime = self.conv1(x)
400 | y_prime = self.conv2(y_prime)
401 |
402 | x = self.conv_res(y_prime)
403 | upsample_size = torch.Size([_s * self.scale for _s in y_prime.shape[-2:]])
404 | x = F.upsample(x, size=upsample_size, mode="nearest")
405 | z_prime = z + x
406 |
407 | return y_prime, z_prime
408 |
409 |
410 | class RU(nn.Module):
411 | """
412 | Residual Unit for FRRN
413 | """
414 |
415 | def __init__(self, channels, kernel_size=3, strides=1, group_norm=False, n_groups=None):
416 | super(RU, self).__init__()
417 | self.group_norm = group_norm
418 | self.n_groups = n_groups
419 |
420 | if self.group_norm:
421 | self.conv1 = conv2DGroupNormRelu(
422 | channels,
423 | channels,
424 | k_size=kernel_size,
425 | stride=strides,
426 | padding=1,
427 | bias=False,
428 | n_groups=self.n_groups,
429 | )
430 | self.conv2 = conv2DGroupNorm(
431 | channels,
432 | channels,
433 | k_size=kernel_size,
434 | stride=strides,
435 | padding=1,
436 | bias=False,
437 | n_groups=self.n_groups,
438 | )
439 |
440 | else:
441 | self.conv1 = conv2DBatchNormRelu(
442 | channels, channels, k_size=kernel_size, stride=strides, padding=1, bias=False
443 | )
444 | self.conv2 = conv2DBatchNorm(
445 | channels, channels, k_size=kernel_size, stride=strides, padding=1, bias=False
446 | )
447 |
448 | def forward(self, x):
449 | incoming = x
450 | x = self.conv1(x)
451 | x = self.conv2(x)
452 | return x + incoming
453 |
454 |
455 | class residualConvUnit(nn.Module):
456 | def __init__(self, channels, kernel_size=3):
457 | super(residualConvUnit, self).__init__()
458 |
459 |         self.residual_conv_unit = nn.Sequential(
460 |             nn.ReLU(inplace=True),
461 |             nn.Conv2d(channels, channels, kernel_size=kernel_size, padding=kernel_size // 2),
462 |             nn.ReLU(inplace=True),
463 |             nn.Conv2d(channels, channels, kernel_size=kernel_size, padding=kernel_size // 2),
464 |         )
465 |
466 | def forward(self, x):
467 | input = x
468 | x = self.residual_conv_unit(x)
469 | return x + input
470 |
471 |
472 | class multiResolutionFusion(nn.Module):
473 | def __init__(self, channels, up_scale_high, up_scale_low, high_shape, low_shape):
474 | super(multiResolutionFusion, self).__init__()
475 |
476 | self.up_scale_high = up_scale_high
477 | self.up_scale_low = up_scale_low
478 |
479 | self.conv_high = nn.Conv2d(high_shape[1], channels, kernel_size=3)
480 |
481 | if low_shape is not None:
482 | self.conv_low = nn.Conv2d(low_shape[1], channels, kernel_size=3)
483 |
484 | def forward(self, x_high, x_low):
485 | high_upsampled = F.upsample(
486 | self.conv_high(x_high), scale_factor=self.up_scale_high, mode="bilinear"
487 | )
488 |
489 | if x_low is None:
490 | return high_upsampled
491 |
492 | low_upsampled = F.upsample(
493 | self.conv_low(x_low), scale_factor=self.up_scale_low, mode="bilinear"
494 | )
495 |
496 | return low_upsampled + high_upsampled
497 |
498 |
499 | class chainedResidualPooling(nn.Module):
500 | def __init__(self, channels, input_shape):
501 | super(chainedResidualPooling, self).__init__()
502 |
503 | self.chained_residual_pooling = nn.Sequential(
504 | nn.ReLU(inplace=True),
505 | nn.MaxPool2d(5, 1, 2),
506 |             nn.Conv2d(input_shape[1], channels, kernel_size=3, padding=1),
507 | )
508 |
509 | def forward(self, x):
510 | input = x
511 | x = self.chained_residual_pooling(x)
512 | return x + input
513 |
514 |
515 | class pyramidPooling(nn.Module):
516 | def __init__(
517 | self, in_channels, pool_sizes, model_name="pspnet", fusion_mode="cat", is_batchnorm=True
518 | ):
519 | super(pyramidPooling, self).__init__()
520 |
521 | bias = not is_batchnorm
522 |
523 | self.paths = []
524 | for i in range(len(pool_sizes)):
525 | self.paths.append(
526 | conv2DBatchNormRelu(
527 | in_channels,
528 | int(in_channels / len(pool_sizes)),
529 | 1,
530 | 1,
531 | 0,
532 | bias=bias,
533 | is_batchnorm=is_batchnorm,
534 | )
535 | )
536 |
537 | self.path_module_list = nn.ModuleList(self.paths)
538 | self.pool_sizes = pool_sizes
539 | self.model_name = model_name
540 | self.fusion_mode = fusion_mode
541 |
542 | def forward(self, x):
543 | h, w = x.shape[2:]
544 |
545 | if self.training or self.model_name != "icnet": # general settings or pspnet
546 | k_sizes = []
547 | strides = []
548 | for pool_size in self.pool_sizes:
549 | k_sizes.append((int(h / pool_size), int(w / pool_size)))
550 | strides.append((int(h / pool_size), int(w / pool_size)))
551 | else: # eval mode and icnet: pre-trained for 1025 x 2049
552 | k_sizes = [(8, 15), (13, 25), (17, 33), (33, 65)]
553 | strides = [(5, 10), (10, 20), (16, 32), (33, 65)]
554 |
555 | if self.fusion_mode == "cat": # pspnet: concat (including x)
556 | output_slices = [x]
557 |
558 | for i, (module, pool_size) in enumerate(zip(self.path_module_list, self.pool_sizes)):
559 | out = F.avg_pool2d(x, k_sizes[i], stride=strides[i], padding=0)
560 | # out = F.adaptive_avg_pool2d(x, output_size=(pool_size, pool_size))
561 | if self.model_name != "icnet":
562 | out = module(out)
563 | out = F.interpolate(out, size=(h, w), mode="bilinear", align_corners=True)
564 | output_slices.append(out)
565 |
566 | return torch.cat(output_slices, dim=1)
567 | else: # icnet: element-wise sum (including x)
568 | pp_sum = x
569 |
570 | for i, (module, pool_size) in enumerate(zip(self.path_module_list, self.pool_sizes)):
571 | out = F.avg_pool2d(x, k_sizes[i], stride=strides[i], padding=0)
572 | # out = F.adaptive_avg_pool2d(x, output_size=(pool_size, pool_size))
573 | if self.model_name != "icnet":
574 | out = module(out)
575 | out = F.interpolate(out, size=(h, w), mode="bilinear", align_corners=True)
576 | pp_sum = pp_sum + out
577 |
578 | return pp_sum
579 |
580 |
581 | class bottleNeckPSP(nn.Module):
582 | def __init__(
583 | self, in_channels, mid_channels, out_channels, stride, dilation=1, is_batchnorm=True
584 | ):
585 | super(bottleNeckPSP, self).__init__()
586 |
587 | bias = not is_batchnorm
588 |
589 | self.cbr1 = conv2DBatchNormRelu(
590 | in_channels, mid_channels, 1, stride=1, padding=0, bias=bias, is_batchnorm=is_batchnorm
591 | )
592 | if dilation > 1:
593 | self.cbr2 = conv2DBatchNormRelu(
594 | mid_channels,
595 | mid_channels,
596 | 3,
597 | stride=stride,
598 | padding=dilation,
599 | bias=bias,
600 | dilation=dilation,
601 | is_batchnorm=is_batchnorm,
602 | )
603 | else:
604 | self.cbr2 = conv2DBatchNormRelu(
605 | mid_channels,
606 | mid_channels,
607 | 3,
608 | stride=stride,
609 | padding=1,
610 | bias=bias,
611 | dilation=1,
612 | is_batchnorm=is_batchnorm,
613 | )
614 | self.cb3 = conv2DBatchNorm(
615 | mid_channels, out_channels, 1, stride=1, padding=0, bias=bias, is_batchnorm=is_batchnorm
616 | )
617 | self.cb4 = conv2DBatchNorm(
618 | in_channels,
619 | out_channels,
620 | 1,
621 | stride=stride,
622 | padding=0,
623 | bias=bias,
624 | is_batchnorm=is_batchnorm,
625 | )
626 |
627 | def forward(self, x):
628 | conv = self.cb3(self.cbr2(self.cbr1(x)))
629 | residual = self.cb4(x)
630 | return F.relu(conv + residual, inplace=True)
631 |
632 |
633 | class bottleNeckIdentifyPSP(nn.Module):
634 | def __init__(self, in_channels, mid_channels, stride, dilation=1, is_batchnorm=True):
635 | super(bottleNeckIdentifyPSP, self).__init__()
636 |
637 | bias = not is_batchnorm
638 |
639 | self.cbr1 = conv2DBatchNormRelu(
640 | in_channels, mid_channels, 1, stride=1, padding=0, bias=bias, is_batchnorm=is_batchnorm
641 | )
642 | if dilation > 1:
643 | self.cbr2 = conv2DBatchNormRelu(
644 | mid_channels,
645 | mid_channels,
646 | 3,
647 | stride=1,
648 | padding=dilation,
649 | bias=bias,
650 | dilation=dilation,
651 | is_batchnorm=is_batchnorm,
652 | )
653 | else:
654 | self.cbr2 = conv2DBatchNormRelu(
655 | mid_channels,
656 | mid_channels,
657 | 3,
658 | stride=1,
659 | padding=1,
660 | bias=bias,
661 | dilation=1,
662 | is_batchnorm=is_batchnorm,
663 | )
664 | self.cb3 = conv2DBatchNorm(
665 | mid_channels, in_channels, 1, stride=1, padding=0, bias=bias, is_batchnorm=is_batchnorm
666 | )
667 |
668 | def forward(self, x):
669 | residual = x
670 | x = self.cb3(self.cbr2(self.cbr1(x)))
671 | return F.relu(x + residual, inplace=True)
672 |
673 |
674 | class residualBlockPSP(nn.Module):
675 | def __init__(
676 | self,
677 | n_blocks,
678 | in_channels,
679 | mid_channels,
680 | out_channels,
681 | stride,
682 | dilation=1,
683 | include_range="all",
684 | is_batchnorm=True,
685 | ):
686 | super(residualBlockPSP, self).__init__()
687 |
688 | if dilation > 1:
689 | stride = 1
690 |
691 | # residualBlockPSP = convBlockPSP + identityBlockPSPs
692 | layers = []
693 | if include_range in ["all", "conv"]:
694 | layers.append(
695 | bottleNeckPSP(
696 | in_channels,
697 | mid_channels,
698 | out_channels,
699 | stride,
700 | dilation,
701 | is_batchnorm=is_batchnorm,
702 | )
703 | )
704 | if include_range in ["all", "identity"]:
705 | for i in range(n_blocks - 1):
706 | layers.append(
707 | bottleNeckIdentifyPSP(
708 | out_channels, mid_channels, stride, dilation, is_batchnorm=is_batchnorm
709 | )
710 | )
711 |
712 | self.layers = nn.Sequential(*layers)
713 |
714 | def forward(self, x):
715 | return self.layers(x)
716 |
717 |
718 | class cascadeFeatureFusion(nn.Module):
719 | def __init__(
720 | self, n_classes, low_in_channels, high_in_channels, out_channels, is_batchnorm=True
721 | ):
722 | super(cascadeFeatureFusion, self).__init__()
723 |
724 | bias = not is_batchnorm
725 |
726 | self.low_dilated_conv_bn = conv2DBatchNorm(
727 | low_in_channels,
728 | out_channels,
729 | 3,
730 | stride=1,
731 | padding=2,
732 | bias=bias,
733 | dilation=2,
734 | is_batchnorm=is_batchnorm,
735 | )
736 | self.low_classifier_conv = nn.Conv2d(
737 | int(low_in_channels),
738 | int(n_classes),
739 | kernel_size=1,
740 | padding=0,
741 | stride=1,
742 | bias=True,
743 | dilation=1,
744 | ) # Train only
745 | self.high_proj_conv_bn = conv2DBatchNorm(
746 | high_in_channels,
747 | out_channels,
748 | 1,
749 | stride=1,
750 | padding=0,
751 | bias=bias,
752 | is_batchnorm=is_batchnorm,
753 | )
754 |
755 | def forward(self, x_low, x_high):
756 | x_low_upsampled = F.interpolate(
757 | x_low, size=get_interp_size(x_low, z_factor=2), mode="bilinear", align_corners=True
758 | )
759 |
760 | low_cls = self.low_classifier_conv(x_low_upsampled)
761 |
762 | low_fm = self.low_dilated_conv_bn(x_low_upsampled)
763 | high_fm = self.high_proj_conv_bn(x_high)
764 | high_fused_fm = F.relu(low_fm + high_fm, inplace=True)
765 |
766 | return high_fused_fm, low_cls
767 |
768 |
769 | def get_interp_size(input, s_factor=1, z_factor=1): # for caffe
770 | ori_h, ori_w = input.shape[2:]
771 |
772 | # shrink (s_factor >= 1)
773 | ori_h = (ori_h - 1) / s_factor + 1
774 | ori_w = (ori_w - 1) / s_factor + 1
775 |
776 | # zoom (z_factor >= 1)
777 | ori_h = ori_h + (ori_h - 1) * (z_factor - 1)
778 | ori_w = ori_w + (ori_w - 1) * (z_factor - 1)
779 |
780 | resize_shape = (int(ori_h), int(ori_w))
781 | return resize_shape
782 |
783 |
784 | def interp(input, output_size, mode="bilinear"):
785 | n, c, ih, iw = input.shape
786 | oh, ow = output_size
787 |
788 | # normalize to [-1, 1]
789 | h = torch.arange(0, oh, dtype=torch.float, device=input.device) / (oh - 1) * 2 - 1
790 | w = torch.arange(0, ow, dtype=torch.float, device=input.device) / (ow - 1) * 2 - 1
791 |
792 | grid = torch.zeros(oh, ow, 2, dtype=torch.float, device=input.device)
793 | grid[:, :, 0] = w.unsqueeze(0).repeat(oh, 1)
794 | grid[:, :, 1] = h.unsqueeze(0).repeat(ow, 1).transpose(0, 1)
795 | grid = grid.unsqueeze(0).repeat(n, 1, 1, 1) # grid.shape: [n, oh, ow, 2]
796 | grid = Variable(grid)
797 | if input.is_cuda:
798 | grid = grid.cuda()
799 |
800 | return F.grid_sample(input, grid, mode=mode)
801 |
802 |
803 | def get_upsampling_weight(in_channels, out_channels, kernel_size):
804 | """Make a 2D bilinear kernel suitable for upsampling"""
805 | factor = (kernel_size + 1) // 2
806 | if kernel_size % 2 == 1:
807 | center = factor - 1
808 | else:
809 | center = factor - 0.5
810 | og = np.ogrid[:kernel_size, :kernel_size]
811 | filt = (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)
812 | weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size), dtype=np.float64)
813 | weight[range(in_channels), range(out_channels), :, :] = filt
814 | return torch.from_numpy(weight).float()
815 |
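get_upsampling_weight above is the standard way to seed a transposed convolution with a fixed bilinear kernel (as in FCN-style upsampling layers); a minimal sketch:

import torch.nn as nn
from ptsemseg.models.utils import get_upsampling_weight

n_classes = 21
# Stride-2 transposed conv initialised to behave like bilinear upsampling;
# the returned weight has shape (in_channels, out_channels, k, k).
upscore = nn.ConvTranspose2d(n_classes, n_classes, kernel_size=4, stride=2, bias=False)
upscore.weight.data.copy_(get_upsampling_weight(n_classes, n_classes, 4))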
--------------------------------------------------------------------------------
/ptsemseg/optimizers/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from torch.optim import SGD, Adam, ASGD, Adamax, Adadelta, Adagrad, RMSprop
4 |
5 | logger = logging.getLogger("ptsemseg")
6 |
7 | key2opt = {
8 | "sgd": SGD,
9 | "adam": Adam,
10 | "asgd": ASGD,
11 | "adamax": Adamax,
12 | "adadelta": Adadelta,
13 | "adagrad": Adagrad,
14 | "rmsprop": RMSprop,
15 | }
16 |
17 |
18 | def get_optimizer(cfg):
19 | if cfg["training"]["optimizer"] is None:
20 | logger.info("Using SGD optimizer")
21 | return SGD
22 |
23 | else:
24 | opt_name = cfg["training"]["optimizer"]["name"]
25 | if opt_name not in key2opt:
26 | raise NotImplementedError("Optimizer {} not implemented".format(opt_name))
27 |
28 | logger.info("Using {} optimizer".format(opt_name))
29 | return key2opt[opt_name]
30 |
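get_optimizer only resolves the optimizer class from the config name; the remaining keys of the optimizer block are passed through as keyword arguments by the caller, as train.py does. A minimal sketch with a stand-in model and illustrative hyperparameters:

import torch
from ptsemseg.optimizers import get_optimizer

cfg = {"training": {"optimizer": {"name": "sgd", "lr": 1.0e-2, "momentum": 0.9, "weight_decay": 5.0e-4}}}

model = torch.nn.Linear(4, 2)          # stand-in model for illustration
optimizer_cls = get_optimizer(cfg)     # -> torch.optim.SGD
optimizer_params = {k: v for k, v in cfg["training"]["optimizer"].items() if k != "name"}
optimizer = optimizer_cls(model.parameters(), **optimizer_params)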
--------------------------------------------------------------------------------
/ptsemseg/schedulers/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from torch.optim.lr_scheduler import MultiStepLR, ExponentialLR, CosineAnnealingLR
4 |
5 | from ptsemseg.schedulers.schedulers import WarmUpLR, ConstantLR, PolynomialLR
6 |
7 | logger = logging.getLogger("ptsemseg")
8 |
9 | key2scheduler = {
10 | "constant_lr": ConstantLR,
11 | "poly_lr": PolynomialLR,
12 | "multi_step": MultiStepLR,
13 | "cosine_annealing": CosineAnnealingLR,
14 | "exp_lr": ExponentialLR,
15 | }
16 |
17 |
18 | def get_scheduler(optimizer, scheduler_dict):
19 | if scheduler_dict is None:
20 | logger.info("Using No LR Scheduling")
21 | return ConstantLR(optimizer)
22 |
23 | s_type = scheduler_dict["name"]
24 | scheduler_dict.pop("name")
25 |
26 |     logger.info("Using {} scheduler with {} params".format(s_type, scheduler_dict))
27 |
28 | warmup_dict = {}
29 | if "warmup_iters" in scheduler_dict:
30 | # This can be done in a more pythonic way...
31 | warmup_dict["warmup_iters"] = scheduler_dict.get("warmup_iters", 100)
32 | warmup_dict["mode"] = scheduler_dict.get("warmup_mode", "linear")
33 | warmup_dict["gamma"] = scheduler_dict.get("warmup_factor", 0.2)
34 |
35 | logger.info(
36 | "Using Warmup with {} iters {} gamma and {} mode".format(
37 | warmup_dict["warmup_iters"], warmup_dict["gamma"], warmup_dict["mode"]
38 | )
39 | )
40 |
41 | scheduler_dict.pop("warmup_iters", None)
42 | scheduler_dict.pop("warmup_mode", None)
43 | scheduler_dict.pop("warmup_factor", None)
44 |
45 | base_scheduler = key2scheduler[s_type](optimizer, **scheduler_dict)
46 | return WarmUpLR(optimizer, base_scheduler, **warmup_dict)
47 |
48 | return key2scheduler[s_type](optimizer, **scheduler_dict)
49 |
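A sketch of how an lr_schedule block is resolved here; the values are illustrative rather than taken from the shipped configs. Note that get_scheduler pops keys from the dict it receives, so pass a copy if the same config dict is reused elsewhere:

import torch
from ptsemseg.schedulers import get_scheduler

model = torch.nn.Linear(4, 2)                              # stand-in model
optimizer = torch.optim.SGD(model.parameters(), lr=1.0e-2)

# Polynomial decay wrapped in a 100-iteration linear warmup
scheduler_dict = {
    "name": "poly_lr",
    "max_iter": 30000,
    "warmup_iters": 100,
    "warmup_mode": "linear",
    "warmup_factor": 0.2,
}
scheduler = get_scheduler(optimizer, scheduler_dict)       # -> WarmUpLR wrapping PolynomialLR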
--------------------------------------------------------------------------------
/ptsemseg/schedulers/schedulers.py:
--------------------------------------------------------------------------------
1 | from torch.optim.lr_scheduler import _LRScheduler
2 |
3 |
4 | class ConstantLR(_LRScheduler):
5 | def __init__(self, optimizer, last_epoch=-1):
6 | super(ConstantLR, self).__init__(optimizer, last_epoch)
7 |
8 | def get_lr(self):
9 | return [base_lr for base_lr in self.base_lrs]
10 |
11 |
12 | class PolynomialLR(_LRScheduler):
13 | def __init__(self, optimizer, max_iter, decay_iter=1, gamma=0.9, last_epoch=-1):
14 | self.decay_iter = decay_iter
15 | self.max_iter = max_iter
16 | self.gamma = gamma
17 | super(PolynomialLR, self).__init__(optimizer, last_epoch)
18 |
19 | def get_lr(self):
20 |         # Polynomial decay: lr = base_lr * (1 - iter / max_iter) ** gamma,
21 |         # applied every `decay_iter` iterations and clamped at `max_iter`
22 |         decay_steps = min(self.last_epoch, self.max_iter) // self.decay_iter * self.decay_iter
23 |         factor = (1 - decay_steps / float(self.max_iter)) ** self.gamma
24 |         return [base_lr * factor for base_lr in self.base_lrs]
25 |
26 |
27 | class WarmUpLR(_LRScheduler):
28 | def __init__(
29 | self, optimizer, scheduler, mode="linear", warmup_iters=100, gamma=0.2, last_epoch=-1
30 | ):
31 | self.mode = mode
32 | self.scheduler = scheduler
33 | self.warmup_iters = warmup_iters
34 | self.gamma = gamma
35 | super(WarmUpLR, self).__init__(optimizer, last_epoch)
36 |
37 | def get_lr(self):
38 | cold_lrs = self.scheduler.get_lr()
39 |
40 | if self.last_epoch < self.warmup_iters:
41 | if self.mode == "linear":
42 | alpha = self.last_epoch / float(self.warmup_iters)
43 | factor = self.gamma * (1 - alpha) + alpha
44 |
45 | elif self.mode == "constant":
46 | factor = self.gamma
47 | else:
48 | raise KeyError("WarmUp type {} not implemented".format(self.mode))
49 |
50 | return [factor * base_lr for base_lr in cold_lrs]
51 |
52 | return cold_lrs
53 |
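WarmUpLR only rescales the learning rates produced by the wrapped scheduler during the first warmup_iters steps; a tiny sketch of the ramp with a ConstantLR base:

import torch
from ptsemseg.schedulers.schedulers import ConstantLR, WarmUpLR

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1)
scheduler = WarmUpLR(optimizer, ConstantLR(optimizer), mode="linear", warmup_iters=5, gamma=0.1)

for step in range(8):
    print(step, optimizer.param_groups[0]["lr"])   # ramps 0.01 -> 0.1 over 5 steps, then stays at 0.1
    scheduler.step()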
--------------------------------------------------------------------------------
/ptsemseg/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Misc Utility functions
3 | """
4 | import os
5 | import logging
6 | import datetime
7 | import numpy as np
8 |
9 | from collections import OrderedDict
10 |
11 |
12 | def recursive_glob(rootdir=".", suffix=""):
13 | """Performs recursive glob with given suffix and rootdir
14 |     :param rootdir: the root directory
15 |     :param suffix: the suffix to be searched
16 | """
17 | return [
18 | os.path.join(looproot, filename)
19 | for looproot, _, filenames in os.walk(rootdir)
20 | for filename in filenames
21 | if filename.endswith(suffix)
22 | ]
23 |
24 |
25 | def alpha_blend(input_image, segmentation_mask, alpha=0.5):
26 |     """Alpha Blending utility to overlay RGB masks on RGB images
27 |     :param input_image: a np.ndarray with 3 channels
28 |     :param segmentation_mask: a np.ndarray with 3 channels
29 |     :param alpha: a float blending weight in [0, 1]
30 |     """
31 |     blended = np.zeros(input_image.shape, dtype=np.float32)
32 | blended = input_image * alpha + segmentation_mask * (1 - alpha)
33 | return blended
34 |
35 |
36 | def convert_state_dict(state_dict):
37 | """Converts a state dict saved from a dataParallel module to normal
38 | module state_dict inplace
39 |     :param state_dict: the loaded DataParallel model_state
40 | """
41 | if not next(iter(state_dict)).startswith("module."):
42 | return state_dict # abort if dict is not a DataParallel model_state
43 | new_state_dict = OrderedDict()
44 | for k, v in state_dict.items():
45 | name = k[7:] # remove `module.`
46 | new_state_dict[name] = v
47 | return new_state_dict
48 |
49 |
50 | def get_logger(logdir):
51 | logger = logging.getLogger("ptsemseg")
52 | ts = str(datetime.datetime.now()).split(".")[0].replace(" ", "_")
53 | ts = ts.replace(":", "_").replace("-", "_")
54 | file_path = os.path.join(logdir, "run_{}.log".format(ts))
55 | hdlr = logging.FileHandler(file_path)
56 | formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s")
57 | hdlr.setFormatter(formatter)
58 | logger.addHandler(hdlr)
59 | logger.setLevel(logging.INFO)
60 | return logger
61 |
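convert_state_dict is needed whenever a checkpoint written by a DataParallel-wrapped model (train.py wraps the model in DataParallel before saving) is loaded into a bare model; a sketch mirroring the call pattern in test.py, with a placeholder checkpoint path:

import torch
from ptsemseg.models import get_model
from ptsemseg.utils import convert_state_dict

# Placeholder path to a checkpoint produced by train.py (keys prefixed with "module.")
checkpoint = torch.load("path/to/fcn8s_pascal_best_model.pkl", map_location="cpu")

model = get_model({"arch": "fcn8s"}, 21, version="pascal")
model.load_state_dict(convert_state_dict(checkpoint["model_state"]))
model.eval()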
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib==2.0.0
2 | numpy==1.12.1
3 | scipy==0.19.0
4 | torch==0.4.1
5 | torchvision==0.2.0
6 | tqdm==4.11.2
7 | pydensecrf
8 | protobuf
9 | tensorboardX
10 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import argparse
4 | import numpy as np
5 | import scipy.misc as misc
6 |
7 |
8 | from ptsemseg.models import get_model
9 | from ptsemseg.loader import get_loader
10 | from ptsemseg.utils import convert_state_dict
11 |
12 | try:
13 | import pydensecrf.densecrf as dcrf
14 | except ImportError:
15 | print(
16 | "Failed to import pydensecrf,\
17 | CRF post-processing will not work"
18 | )
19 |
20 |
21 | def test(args):
22 |
23 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24 |
25 | model_file_name = os.path.split(args.model_path)[1]
26 | model_name = model_file_name[: model_file_name.find("_")]
27 |
28 | # Setup image
29 | print("Read Input Image from : {}".format(args.img_path))
30 | img = misc.imread(args.img_path)
31 |
32 | data_loader = get_loader(args.dataset)
33 | loader = data_loader(root=None, is_transform=True, img_norm=args.img_norm, test_mode=True)
34 | n_classes = loader.n_classes
35 |
36 | resized_img = misc.imresize(img, (loader.img_size[0], loader.img_size[1]), interp="bicubic")
37 |
38 | orig_size = img.shape[:-1]
39 | if model_name in ["pspnet", "icnet", "icnetBN"]:
40 | # uint8 with RGB mode, resize width and height which are odd numbers
41 | img = misc.imresize(img, (orig_size[0] // 2 * 2 + 1, orig_size[1] // 2 * 2 + 1))
42 | else:
43 | img = misc.imresize(img, (loader.img_size[0], loader.img_size[1]))
44 |
45 | img = img[:, :, ::-1]
46 | img = img.astype(np.float64)
47 | img -= loader.mean
48 | if args.img_norm:
49 | img = img.astype(float) / 255.0
50 |
51 | # NHWC -> NCHW
52 | img = img.transpose(2, 0, 1)
53 | img = np.expand_dims(img, 0)
54 | img = torch.from_numpy(img).float()
55 |
56 | # Setup Model
57 | model_dict = {"arch": model_name}
58 | model = get_model(model_dict, n_classes, version=args.dataset)
59 | state = convert_state_dict(torch.load(args.model_path)["model_state"])
60 | model.load_state_dict(state)
61 | model.eval()
62 | model.to(device)
63 |
64 | images = img.to(device)
65 | outputs = model(images)
66 |
67 | if args.dcrf:
68 | unary = outputs.data.cpu().numpy()
69 | unary = np.squeeze(unary, 0)
70 | unary = -np.log(unary)
71 | unary = unary.transpose(2, 1, 0)
72 | w, h, c = unary.shape
73 | unary = unary.transpose(2, 0, 1).reshape(loader.n_classes, -1)
74 | unary = np.ascontiguousarray(unary)
75 |
76 | resized_img = np.ascontiguousarray(resized_img)
77 |
78 | d = dcrf.DenseCRF2D(w, h, loader.n_classes)
79 | d.setUnaryEnergy(unary)
80 | d.addPairwiseBilateral(sxy=5, srgb=3, rgbim=resized_img, compat=1)
81 |
82 | q = d.inference(50)
83 | mask = np.argmax(q, axis=0).reshape(w, h).transpose(1, 0)
84 | decoded_crf = loader.decode_segmap(np.array(mask, dtype=np.uint8))
85 | dcrf_path = args.out_path[:-4] + "_drf.png"
86 | misc.imsave(dcrf_path, decoded_crf)
87 | print("Dense CRF Processed Mask Saved at: {}".format(dcrf_path))
88 |
89 | pred = np.squeeze(outputs.data.max(1)[1].cpu().numpy(), axis=0)
90 | if model_name in ["pspnet", "icnet", "icnetBN"]:
91 | pred = pred.astype(np.float32)
92 | # float32 with F mode, resize back to orig_size
93 | pred = misc.imresize(pred, orig_size, "nearest", mode="F")
94 |
95 | decoded = loader.decode_segmap(pred)
96 | print("Classes found: ", np.unique(pred))
97 | misc.imsave(args.out_path, decoded)
98 | print("Segmentation Mask Saved at: {}".format(args.out_path))
99 |
100 |
101 | if __name__ == "__main__":
102 | parser = argparse.ArgumentParser(description="Params")
103 | parser.add_argument(
104 | "--model_path",
105 | nargs="?",
106 | type=str,
107 | default="fcn8s_pascal_1_26.pkl",
108 | help="Path to the saved model",
109 | )
110 | parser.add_argument(
111 | "--dataset",
112 | nargs="?",
113 | type=str,
114 | default="pascal",
115 | help="Dataset to use ['pascal, camvid, ade20k etc']",
116 | )
117 |
118 | parser.add_argument(
119 | "--img_norm",
120 | dest="img_norm",
121 | action="store_true",
122 | help="Enable input image scales normalization [0, 1] \
123 | | True by default",
124 | )
125 | parser.add_argument(
126 | "--no-img_norm",
127 | dest="img_norm",
128 | action="store_false",
129 | help="Disable input image scales normalization [0, 1] |\
130 | True by default",
131 | )
132 | parser.set_defaults(img_norm=True)
133 |
134 | parser.add_argument(
135 | "--dcrf",
136 | dest="dcrf",
137 | action="store_true",
138 | help="Enable DenseCRF based post-processing | \
139 | False by default",
140 | )
141 | parser.add_argument(
142 | "--no-dcrf",
143 | dest="dcrf",
144 | action="store_false",
145 | help="Disable DenseCRF based post-processing | \
146 | False by default",
147 | )
148 | parser.set_defaults(dcrf=False)
149 |
150 | parser.add_argument(
151 | "--img_path", nargs="?", type=str, default=None, help="Path of the input image"
152 | )
153 | parser.add_argument(
154 | "--out_path", nargs="?", type=str, default=None, help="Path of the output segmap"
155 | )
156 | args = parser.parse_args()
157 | test(args)
158 |
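The script above is normally driven from the command line; the programmatic equivalent just builds the same argparse namespace by hand (all paths below are placeholders):

from argparse import Namespace
from test import test

# Equivalent to:
#   python test.py --model_path pspnet_101_cityscapes.pth --dataset cityscapes \
#                  --img_path street.png --out_path street_segmap.png
args = Namespace(
    model_path="pspnet_101_cityscapes.pth",   # placeholder checkpoint
    dataset="cityscapes",
    img_norm=True,
    dcrf=False,
    img_path="street.png",                    # placeholder input image
    out_path="street_segmap.png",
)
test(args)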
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import yaml
3 | import time
4 | import shutil
5 | import torch
6 | import random
7 | import argparse
8 | import numpy as np
9 |
10 | from torch.utils import data
11 | from tqdm import tqdm
12 |
13 | from ptsemseg.models import get_model
14 | from ptsemseg.loss import get_loss_function
15 | from ptsemseg.loader import get_loader
16 | from ptsemseg.utils import get_logger
17 | from ptsemseg.metrics import runningScore, averageMeter
18 | from ptsemseg.augmentations import get_composed_augmentations
19 | from ptsemseg.schedulers import get_scheduler
20 | from ptsemseg.optimizers import get_optimizer
21 |
22 | from tensorboardX import SummaryWriter
23 |
24 |
25 | def train(cfg, writer, logger):
26 |
27 | # Setup seeds
28 | torch.manual_seed(cfg.get("seed", 1337))
29 | torch.cuda.manual_seed(cfg.get("seed", 1337))
30 | np.random.seed(cfg.get("seed", 1337))
31 | random.seed(cfg.get("seed", 1337))
32 |
33 | # Setup device
34 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
35 |
36 | # Setup Augmentations
37 | augmentations = cfg["training"].get("augmentations", None)
38 | data_aug = get_composed_augmentations(augmentations)
39 |
40 | # Setup Dataloader
41 | data_loader = get_loader(cfg["data"]["dataset"])
42 | data_path = cfg["data"]["path"]
43 |
44 | t_loader = data_loader(
45 | data_path,
46 | is_transform=True,
47 | split=cfg["data"]["train_split"],
48 | img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
49 | augmentations=data_aug,
50 | )
51 |
52 | v_loader = data_loader(
53 | data_path,
54 | is_transform=True,
55 | split=cfg["data"]["val_split"],
56 | img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
57 | )
58 |
59 | n_classes = t_loader.n_classes
60 | trainloader = data.DataLoader(
61 | t_loader,
62 | batch_size=cfg["training"]["batch_size"],
63 | num_workers=cfg["training"]["n_workers"],
64 | shuffle=True,
65 | )
66 |
67 | valloader = data.DataLoader(
68 | v_loader, batch_size=cfg["training"]["batch_size"], num_workers=cfg["training"]["n_workers"]
69 | )
70 |
71 | # Setup Metrics
72 | running_metrics_val = runningScore(n_classes)
73 |
74 | # Setup Model
75 | model = get_model(cfg["model"], n_classes).to(device)
76 |
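   |     # Wrap the model for multi-GPU training; DataParallel prefixes saved parameter
   |     # names with "module.", which test.py / validate.py undo via convert_state_dict.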
77 | model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))
78 |
79 | # Setup optimizer, lr_scheduler and loss function
80 | optimizer_cls = get_optimizer(cfg)
81 | optimizer_params = {k: v for k, v in cfg["training"]["optimizer"].items() if k != "name"}
82 |
83 | optimizer = optimizer_cls(model.parameters(), **optimizer_params)
84 | logger.info("Using optimizer {}".format(optimizer))
85 |
86 | scheduler = get_scheduler(optimizer, cfg["training"]["lr_schedule"])
87 |
88 | loss_fn = get_loss_function(cfg)
89 | logger.info("Using loss {}".format(loss_fn))
90 |
91 | start_iter = 0
92 | if cfg["training"]["resume"] is not None:
93 | if os.path.isfile(cfg["training"]["resume"]):
94 | logger.info(
95 | "Loading model and optimizer from checkpoint '{}'".format(cfg["training"]["resume"])
96 | )
97 | checkpoint = torch.load(cfg["training"]["resume"])
98 | model.load_state_dict(checkpoint["model_state"])
99 | optimizer.load_state_dict(checkpoint["optimizer_state"])
100 | scheduler.load_state_dict(checkpoint["scheduler_state"])
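   |             # The "epoch" field holds the iteration count at save time (see the
   |             # checkpoint dict built below), so training resumes mid-schedule.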
101 | start_iter = checkpoint["epoch"]
102 | logger.info(
103 | "Loaded checkpoint '{}' (iter {})".format(
104 | cfg["training"]["resume"], checkpoint["epoch"]
105 | )
106 | )
107 | else:
108 | logger.info("No checkpoint found at '{}'".format(cfg["training"]["resume"]))
109 |
110 | val_loss_meter = averageMeter()
111 | time_meter = averageMeter()
112 |
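   |     # Start below any reachable score so the first validation pass always saves a
   |     # checkpoint.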
113 | best_iou = -100.0
114 | i = start_iter
115 | flag = True
116 |
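   |     # Iteration-based training: keep cycling over the dataloader until train_iters
   |     # is reached; `flag` breaks out of both loops once the budget is spent.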
117 | while i <= cfg["training"]["train_iters"] and flag:
118 | for (images, labels) in trainloader:
119 | i += 1
120 | start_ts = time.time()
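   |             # NOTE: since PyTorch 1.1 the recommended order is optimizer.step()
   |             # before scheduler.step(); calling the scheduler first, as here, shifts
   |             # the LR schedule forward by one iteration.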
121 | scheduler.step()
122 | model.train()
123 | images = images.to(device)
124 | labels = labels.to(device)
125 |
126 | optimizer.zero_grad()
127 | outputs = model(images)
128 |
129 | loss = loss_fn(input=outputs, target=labels)
130 |
131 | loss.backward()
132 | optimizer.step()
133 |
134 | time_meter.update(time.time() - start_ts)
135 |
136 | if (i + 1) % cfg["training"]["print_interval"] == 0:
137 | fmt_str = "Iter [{:d}/{:d}] Loss: {:.4f} Time/Image: {:.4f}"
138 | print_str = fmt_str.format(
139 | i + 1,
140 | cfg["training"]["train_iters"],
141 | loss.item(),
142 | time_meter.avg / cfg["training"]["batch_size"],
143 | )
144 |
145 | print(print_str)
146 | logger.info(print_str)
147 | writer.add_scalar("loss/train_loss", loss.item(), i + 1)
148 | time_meter.reset()
149 |
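   |             # Periodic validation: run the whole val split without gradients, log
   |             # the averaged loss and metrics to TensorBoard, and checkpoint whenever
   |             # the mean IoU improves on the best seen so far.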
150 | if (i + 1) % cfg["training"]["val_interval"] == 0 or (i + 1) == cfg["training"][
151 | "train_iters"
152 | ]:
153 | model.eval()
154 | with torch.no_grad():
155 | for i_val, (images_val, labels_val) in tqdm(enumerate(valloader)):
156 | images_val = images_val.to(device)
157 | labels_val = labels_val.to(device)
158 |
159 | outputs = model(images_val)
160 | val_loss = loss_fn(input=outputs, target=labels_val)
161 |
162 | pred = outputs.data.max(1)[1].cpu().numpy()
163 | gt = labels_val.data.cpu().numpy()
164 |
165 | running_metrics_val.update(gt, pred)
166 | val_loss_meter.update(val_loss.item())
167 |
168 | writer.add_scalar("loss/val_loss", val_loss_meter.avg, i + 1)
169 | logger.info("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg))
170 |
171 | score, class_iou = running_metrics_val.get_scores()
172 | for k, v in score.items():
173 | print(k, v)
174 | logger.info("{}: {}".format(k, v))
175 | writer.add_scalar("val_metrics/{}".format(k), v, i + 1)
176 |
177 | for k, v in class_iou.items():
178 | logger.info("{}: {}".format(k, v))
179 | writer.add_scalar("val_metrics/cls_{}".format(k), v, i + 1)
180 |
181 | val_loss_meter.reset()
182 | running_metrics_val.reset()
183 |
184 | if score["Mean IoU : \t"] >= best_iou:
185 | best_iou = score["Mean IoU : \t"]
186 | state = {
187 | "epoch": i + 1,
188 | "model_state": model.state_dict(),
189 | "optimizer_state": optimizer.state_dict(),
190 | "scheduler_state": scheduler.state_dict(),
191 | "best_iou": best_iou,
192 | }
193 | save_path = os.path.join(
194 | writer.file_writer.get_logdir(),
195 | "{}_{}_best_model.pkl".format(cfg["model"]["arch"], cfg["data"]["dataset"]),
196 | )
197 | torch.save(state, save_path)
198 |
199 | if (i + 1) == cfg["training"]["train_iters"]:
200 | flag = False
201 | break
202 |
203 |
204 | if __name__ == "__main__":
205 | parser = argparse.ArgumentParser(description="config")
206 | parser.add_argument(
207 | "--config",
208 | nargs="?",
209 | type=str,
210 | default="configs/fcn8s_pascal.yml",
211 | help="Configuration file to use",
212 | )
213 |
214 | args = parser.parse_args()
215 |
216 | with open(args.config) as fp:
217 |         cfg = yaml.load(fp, Loader=yaml.SafeLoader)
218 |
219 | run_id = random.randint(1, 100000)
220 | logdir = os.path.join("runs", os.path.basename(args.config)[:-4], str(run_id))
221 | writer = SummaryWriter(log_dir=logdir)
222 |
223 | print("RUNDIR: {}".format(logdir))
224 | shutil.copy(args.config, logdir)
225 |
226 | logger = get_logger(logdir)
227 | logger.info("Let the games begin")
228 |
229 | train(cfg, writer, logger)
230 |
--------------------------------------------------------------------------------
/validate.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | import torch
3 | import argparse
4 | import timeit
5 | import numpy as np
6 |
7 | from torch.utils import data
8 |
9 |
10 | from ptsemseg.models import get_model
11 | from ptsemseg.loader import get_loader
12 | from ptsemseg.metrics import runningScore
13 | from ptsemseg.utils import convert_state_dict
14 |
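   | # Let cuDNN benchmark and cache the fastest convolution algorithms; this pays off
   | # because the resized validation inputs all share the same shape.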
15 | torch.backends.cudnn.benchmark = True
16 |
17 |
18 | def validate(cfg, args):
19 |
20 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21 |
22 | # Setup Dataloader
23 | data_loader = get_loader(cfg["data"]["dataset"])
24 | data_path = cfg["data"]["path"]
25 |
26 | loader = data_loader(
27 | data_path,
28 | split=cfg["data"]["val_split"],
29 | is_transform=True,
30 | img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
31 | )
32 |
33 | n_classes = loader.n_classes
34 |
35 | valloader = data.DataLoader(loader, batch_size=cfg["training"]["batch_size"], num_workers=8)
36 | running_metrics = runningScore(n_classes)
37 |
38 | # Setup Model
39 |
40 | model = get_model(cfg["model"], n_classes).to(device)
41 | state = convert_state_dict(torch.load(args.model_path)["model_state"])
42 | model.load_state_dict(state)
43 | model.eval()
44 | model.to(device)
45 |
46 | for i, (images, labels) in enumerate(valloader):
47 | start_time = timeit.default_timer()
48 |
49 | images = images.to(device)
50 |
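   |         # Flip test-time augmentation: score the image and its horizontal mirror,
   |         # flip the mirrored scores back, and average the two before the argmax.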
51 | if args.eval_flip:
52 | outputs = model(images)
53 |
54 |             # Flip the images in numpy (flipping is not supported on tensors here)
55 | outputs = outputs.data.cpu().numpy()
56 | flipped_images = np.copy(images.data.cpu().numpy()[:, :, :, ::-1])
57 | flipped_images = torch.from_numpy(flipped_images).float().to(device)
58 | outputs_flipped = model(flipped_images)
59 | outputs_flipped = outputs_flipped.data.cpu().numpy()
60 | outputs = (outputs + outputs_flipped[:, :, :, ::-1]) / 2.0
61 |
62 | pred = np.argmax(outputs, axis=1)
63 | else:
64 | outputs = model(images)
65 | pred = outputs.data.max(1)[1].cpu().numpy()
66 |
67 | gt = labels.numpy()
68 |
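   |         # Throughput = images in the batch / wall-clock time for this iteration.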
69 | if args.measure_time:
70 | elapsed_time = timeit.default_timer() - start_time
71 |             print(
72 |                 "Inference speed "
73 |                 "(iter {0:5d}): {1:3.5f} fps".format(
74 |                     i + 1, pred.shape[0] / elapsed_time
75 |                 )
76 |             )
77 | running_metrics.update(gt, pred)
78 |
79 | score, class_iou = running_metrics.get_scores()
80 |
81 | for k, v in score.items():
82 | print(k, v)
83 |
84 | for i in range(n_classes):
85 | print(i, class_iou[i])
86 |
87 |
88 | if __name__ == "__main__":
89 | parser = argparse.ArgumentParser(description="Hyperparams")
90 | parser.add_argument(
91 | "--config",
92 | nargs="?",
93 | type=str,
94 | default="configs/fcn8s_pascal.yml",
95 | help="Config file to be used",
96 | )
97 | parser.add_argument(
98 | "--model_path",
99 | nargs="?",
100 | type=str,
101 | default="fcn8s_pascal_1_26.pkl",
102 | help="Path to the saved model",
103 | )
104 | parser.add_argument(
105 | "--eval_flip",
106 | dest="eval_flip",
107 | action="store_true",
108 | help="Enable evaluation with flipped image |\
109 | True by default",
110 | )
111 | parser.add_argument(
112 | "--no-eval_flip",
113 | dest="eval_flip",
114 | action="store_false",
115 | help="Disable evaluation with flipped image |\
116 | True by default",
117 | )
118 | parser.set_defaults(eval_flip=True)
119 |
120 | parser.add_argument(
121 | "--measure_time",
122 | dest="measure_time",
123 | action="store_true",
124 | help="Enable evaluation with time (fps) measurement |\
125 | True by default",
126 | )
127 | parser.add_argument(
128 | "--no-measure_time",
129 | dest="measure_time",
130 | action="store_false",
131 | help="Disable evaluation with time (fps) measurement |\
132 | True by default",
133 | )
134 | parser.set_defaults(measure_time=True)
135 |
136 | args = parser.parse_args()
137 |
138 | with open(args.config) as fp:
139 |         cfg = yaml.load(fp, Loader=yaml.SafeLoader)
140 |
141 | validate(cfg, args)
142 |
--------------------------------------------------------------------------------