├── .gitignore
├── DeepLabv3.ipynb
├── LICENSE
├── README.md
├── datasets
│   └── dload.sh
├── init.py
├── models
│   ├── __pycache__
│   │   ├── assp.cpython-36.pyc
│   │   ├── deeplabv3.cpython-36.pyc
│   │   └── resnet_50.cpython-36.pyc
│   ├── assp.py
│   ├── deeplabv3.py
│   └── resnet_50.py
├── results
│   ├── CityScapes
│   │   └── README.md
│   └── pascal voc 2012
│       ├── README.md
│       ├── epoch_10.png
│       ├── epoch_10_seg.png
│       ├── epoch_20.png
│       └── epoch_20_seg.png
├── test.py
├── train.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.txt
2 | *.html
3 | *.swp
4 |
5 | # Byte-compiled / optimized / DLL files
6 | __pycache__/
7 | *.py[cod]
8 | *$py.class
9 |
10 | # C extensions
11 | *.so
12 |
13 | # Distribution / packaging
14 | .Python
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | nosetests.xml
49 | coverage.xml
50 | *.cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 |
63 | # Flask stuff:
64 | instance/
65 | .webassets-cache
66 |
67 | # Scrapy stuff:
68 | .scrapy
69 |
70 | # Sphinx documentation
71 | docs/_build/
72 |
73 | # PyBuilder
74 | target/
75 |
76 | # Jupyter Notebook
77 | .ipynb_checkpoints
78 |
79 | # pyenv
80 | .python-version
81 |
82 | # celery beat schedule file
83 | celerybeat-schedule
84 |
85 | # SageMath parsed files
86 | *.sage.py
87 |
88 | # Environments
89 | .env
90 | .venv
91 | env/
92 | venv/
93 | ENV/
94 | env.bak/
95 | venv.bak/
96 |
97 | # Spyder project settings
98 | .spyderproject
99 | .spyproject
100 |
101 | # Rope project settings
102 | .ropeproject
103 |
104 | # mkdocs documentation
105 | /site
106 |
107 | # mypy
108 | .mypy_cache/
109 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Aviv Shamsian
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DeepLabv3
2 |
3 | This repository reproduces the DeepLabv3 paper: [Rethinking Atrous Convolution for Semantic Image Segmentation](https://arxiv.org/pdf/1706.05587.pdf).
4 | DeepLabv3 expects a ResNet50 or ResNet101 feature extractor, so this repository will also contain the code for the ResNet50 and ResNet101 architectures.
5 | We will also release a Colab notebook and pretrained models.
6 |
7 | ## How to use
8 |
9 | 0. This repository comes with a handy notebook that you can use with Colab.
10 | You can find a link to the notebook here:
11 | [DeepLabv3](https://github.com/AvivSham/DeepLabv3.ipynb)
12 | Open it in Colab: [Open in Colab](https://colab.research.google.com/github/AvivSham/DeepLabv3/blob/master/DeepLabv3.ipynb)
13 |
14 | ---
15 |
16 |
17 | 0. Clone the repository and cd into it
18 | ```
19 | git clone https://github.com/AvivSham/DeepLabv3.git
20 | cd DeepLabv3/
21 | ```
22 |
23 | 1. Use this command to train the model
24 | ```
25 | python3 init.py --mode train -iptr /path/to/train/input/set/ -lptr /path/to/label/set/ -nc <num_classes>
26 | ```
27 |
28 | 2. Use this command to test the model
29 | ```
30 | python3 init.py --mode test -m /path/to/model.pth -i /path/to/image.png -nc <num_classes>
31 | ```
32 |
33 | 3. Use `--help` to get more commands
34 | ```
35 | python3 init.py --help
36 | ```
37 |
38 | ---
39 |
40 |
41 | 0. If you want to download the cityscapes dataset
42 | ```
43 | sh ./datasets/dload.sh cityscapes <username> <password>
44 | ```
45 |
46 | 1. If you want to download the PASCAL VOC 2012 datasets
47 | ```
48 | sh ./datasets/dload.sh pascal
49 | ```
50 |
51 | ## Results
52 |
53 | ### Pascal VOC 2012
54 |
55 | ### CityScapes
56 |
57 | ## References
58 | 1. [Rethinking Atrous Convolution for Semantic Image Segmentation](https://arxiv.org/pdf/1706.05587.pdf)
59 | 2. [Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation](https://arxiv.org/pdf/1802.02611.pdf)
60 |
61 | ## License
62 |
63 | The code in this repository is free to use and to modify with proper linkage back to this repository.
64 |
--------------------------------------------------------------------------------
/datasets/dload.sh:
--------------------------------------------------------------------------------
1 | if [ "$1" = "cityscapes" ]; then
2 |     if [ -z "$2" ] || [ -z "$3" ]; then
3 |         echo 'Invalid username / password'
4 |         exit 1
5 |     fi
6 |
7 |     wget --keep-session-cookies --save-cookies=cookies.txt --post-data "username=$2&password=$3&submit=Login" https://www.cityscapes-dataset.com/login/
8 |     wget --load-cookies cookies.txt --content-disposition 'https://www.cityscapes-dataset.com/file-handling/?packageID=1'
9 |     wget --load-cookies cookies.txt --content-disposition 'https://www.cityscapes-dataset.com/file-handling/?packageID=3'
10 |     unzip -qq gtFine_trainvaltest.zip
11 |     unzip -qq leftImg8bit_trainvaltest.zip
12 |
13 | elif [ "$1" = "pascal" ]; then
14 |     wget http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar -O VOCtrainval.tar
15 |     tar -xf VOCtrainval.tar
16 |
17 | else
18 |     echo "Invalid Argument"
19 | fi
20 |
--------------------------------------------------------------------------------
/init.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import numpy as np
4 | import torch
5 |
6 | from train import *
7 | from test import *
8 |
9 |
10 | def str2bool(value):
11 |     # argparse's type=bool treats every non-empty string (even 'False') as True,
12 |     # so boolean flags are parsed explicitly instead.
13 |     if isinstance(value, bool):
14 |         return value
15 |     if value.lower() in ('true', 't', 'yes', 'y', '1'):
16 |         return True
17 |     if value.lower() in ('false', 'f', 'no', 'n', '0'):
18 |         return False
19 |     raise argparse.ArgumentTypeError('Expected a boolean value, got %r' % value)
20 |
21 |
22 | if __name__ == '__main__':
23 |     parser = argparse.ArgumentParser()
24 |
25 |     parser.add_argument('-m',
26 |                         type=str,
27 |                         help='The path to the pretrained cscapes model')
28 |
29 |     parser.add_argument('-i', '--image-path',
30 |                         type=str,
31 |                         help='The path to the image to perform semantic segmentation')
32 |
33 |     parser.add_argument('-rh', '--resize-height',
34 |                         type=int,
35 |                         default=1024,
36 |                         help='The height for the resized image')
37 |
38 |     parser.add_argument('-rw', '--resize-width',
39 |                         type=int,
40 |                         default=2048,
41 |                         help='The width for the resized image')
42 |
43 |     parser.add_argument('-lr', '--learning-rate',
44 |                         type=float,
45 |                         default=1e-3,
46 |                         help='The learning rate')
47 |
48 |     parser.add_argument('-bs', '--batch-size',
49 |                         type=int,
50 |                         default=2,
51 |                         help='The batch size')
52 |
53 |     parser.add_argument('-wd', '--weight-decay',
54 |                         type=float,
55 |                         default=1e-4,
56 |                         help='The weight decay')
57 |
58 |     parser.add_argument('-c', '--constant',
59 |                         type=float,
60 |                         default=1.02,
61 |                         help='The constant used for calculating the class weights')
62 |
63 |     parser.add_argument('-e', '--epochs',
64 |                         type=int,
65 |                         default=100,
66 |                         help='The number of epochs')
67 |
68 |     parser.add_argument('-nc', '--num-classes',
69 |                         type=int,
70 |                         required=True,
71 |                         help='The number of classes')
72 |
73 |     parser.add_argument('-se', '--save-every',
74 |                         type=int,
75 |                         default=10,
76 |                         help='The number of epochs after which to save a model')
77 |
78 |     parser.add_argument('-iptr', '--input-path-train',
79 |                         type=str,
80 |                         help='The path to the training input dataset')
81 |
82 |     parser.add_argument('-lptr', '--label-path-train',
83 |                         type=str,
84 |                         help='The path to the training label dataset')
85 |
86 |     parser.add_argument('-ipv', '--input-path-val',
87 |                         type=str,
88 |                         help='The path to the validation input dataset')
89 |
90 |     parser.add_argument('-lpv', '--label-path-val',
91 |                         type=str,
92 |                         help='The path to the validation label dataset')
93 |
94 |     parser.add_argument('-iptt', '--input-path-test',
95 |                         type=str,
96 |                         help='The path to the test input dataset')
97 |
98 |     parser.add_argument('-lptt', '--label-path-test',
99 |                         type=str,
100 |                         help='The path to the test label dataset')
101 |
102 |     parser.add_argument('-pe', '--print-every',
103 |                         type=int,
104 |                         default=1,
105 |                         help='The number of epochs after which to print the training loss')
106 |
107 |     parser.add_argument('-ee', '--eval-every',
108 |                         type=int,
109 |                         default=10,
110 |                         help='The number of epochs after which to print the validation loss')
111 |
112 |     parser.add_argument('--cuda',
113 |                         type=str2bool,
114 |                         default=False,
115 |                         help='Whether to use cuda or not')
116 |
117 |     parser.add_argument('--mode',
118 |                         choices=['train', 'test'],
119 |                         default='train',
120 |                         help='Whether to train or test')
121 |
122 |     parser.add_argument('-dt', '--dtype',
123 |                         choices=['cityscapes', 'pascal'],
124 |                         default='pascal',
125 |                         help='specify the dataset you are using')
126 |
127 |     parser.add_argument('--scheduler',
128 |                         type=str2bool,
129 |                         default=False,
130 |                         help='Whether to use scheduler or not')
131 |
132 |     parser.add_argument('--save',
133 |                         type=str2bool,
134 |                         default=True,
135 |                         help='Save the segmented image when predicting')
136 |
137 |     FLAGS, unparsed = parser.parse_known_args()
138 |
139 |     FLAGS.cuda = torch.device('cuda:0' if torch.cuda.is_available() and FLAGS.cuda
140 |                               else 'cpu')
141 |
142 |     print ('[INFO]Arguments read successfully!')
143 |
144 |     if FLAGS.mode.lower() == 'train':
145 |         print ('[INFO]Train Mode.')
146 |
147 |         # argparse stores values under the long option name (input_path_train),
148 |         # not the short one (iptr). train() only reads the training paths.
149 |         if FLAGS.input_path_train is None or FLAGS.label_path_train is None:
150 |             raise RuntimeError('Error: Kindly provide the path to the dataset')
151 |
152 |         train(FLAGS)
153 |
154 |     elif FLAGS.mode.lower() == 'test':
155 |         print ('[INFO]Predict Mode.')
156 |         predict(FLAGS)
157 |     else:
158 |         raise RuntimeError('Unknown mode passed. \n Mode passed should be either '
159 |                            'of "train" or "test"')
160 |
--------------------------------------------------------------------------------
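A note on how `init.py` reads its flags: `argparse` stores each value under a name derived from the first long option (dashes become underscores) and falls back to the short option only when no long one is given. The sketch below is a minimal, standalone illustration of that rule. It reuses three of the option names from `init.py`, but the argument values (`model.pth`, `data/train`, `21`) are made up for the example.

```python
import argparse

parser = argparse.ArgumentParser()
# Short option only: the attribute is named after the short option ("m").
parser.add_argument('-m', type=str)
# Short and long option: the attribute comes from the long option,
# with dashes turned into underscores ("input_path_train").
parser.add_argument('-iptr', '--input-path-train', type=str)
parser.add_argument('-nc', '--num-classes', type=int, required=True)

# Hypothetical command line, used only for this illustration.
flags = parser.parse_args(['-m', 'model.pth', '-iptr', 'data/train', '-nc', '21'])

print(flags.m)
print(flags.input_path_train)
print(flags.num_classes)
print(hasattr(flags, 'iptr'))  # the short name is not an attribute
```

This is why code that consumes the parsed flags has to use names such as `FLAGS.input_path_train` and `FLAGS.num_classes`, while the model path is available as `FLAGS.m`.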
/models/__pycache__/assp.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AvivSham/DeepLabv3/c718ec3f8190ca2fc45a52a121e9009ca8284e2f/models/__pycache__/assp.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/deeplabv3.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AvivSham/DeepLabv3/c718ec3f8190ca2fc45a52a121e9009ca8284e2f/models/__pycache__/deeplabv3.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/resnet_50.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AvivSham/DeepLabv3/c718ec3f8190ca2fc45a52a121e9009ca8284e2f/models/__pycache__/resnet_50.cpython-36.pyc
--------------------------------------------------------------------------------
/models/assp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class ASSP(nn.Module):
6 |     def __init__(self,in_channels,out_channels = 256):
7 |         super(ASSP,self).__init__()
8 |
9 |
10 |         self.relu = nn.ReLU(inplace=True)
11 |
12 |         self.conv1 = nn.Conv2d(in_channels = in_channels,
13 |                                out_channels = out_channels,
14 |                                kernel_size = 1,
15 |                                padding = 0,
16 |                                dilation=1,
17 |                                bias=False)
18 |
19 |         self.bn1 = nn.BatchNorm2d(out_channels)
20 |
21 |         self.conv2 = nn.Conv2d(in_channels = in_channels,
22 |                                out_channels = out_channels,
23 |                                kernel_size = 3,
24 |                                stride=1,
25 |                                padding = 6,
26 |                                dilation = 6,
27 |                                bias=False)
28 |
29 |         self.bn2 = nn.BatchNorm2d(out_channels)
30 |
31 |         self.conv3 = nn.Conv2d(in_channels = in_channels,
32 |                                out_channels = out_channels,
33 |                                kernel_size = 3,
34 |                                stride=1,
35 |                                padding = 12,
36 |                                dilation = 12,
37 |                                bias=False)
38 |
39 |         self.bn3 = nn.BatchNorm2d(out_channels)
40 |
41 |         self.conv4 = nn.Conv2d(in_channels = in_channels,
42 |                                out_channels = out_channels,
43 |                                kernel_size = 3,
44 |                                stride=1,
45 |                                padding = 18,
46 |                                dilation = 18,
47 |                                bias=False)
48 |
49 |         self.bn4 = nn.BatchNorm2d(out_channels)
50 |
51 |         self.conv5 = nn.Conv2d(in_channels = in_channels,
52 |                                out_channels = out_channels,
53 |                                kernel_size = 1,
54 |                                stride=1,
55 |                                padding = 0,
56 |                                dilation=1,
57 |                                bias=False)
58 |
59 |         self.bn5 = nn.BatchNorm2d(out_channels)
60 |
61 |         self.convf = nn.Conv2d(in_channels = out_channels * 5,
62 |                                out_channels = out_channels,
63 |                                kernel_size = 1,
64 |                                stride=1,
65 |                                padding = 0,
66 |                                dilation=1,
67 |                                bias=False)
68 |
69 |         self.bnf = nn.BatchNorm2d(out_channels)
70 |
71 |         self.adapool = nn.AdaptiveAvgPool2d(1)
72 |
73 |
74 |     def forward(self,x):
75 |
76 |         x1 = self.conv1(x)
77 |         x1 = self.bn1(x1)
78 |         x1 = self.relu(x1)
79 |
80 |         x2 = self.conv2(x)
81 |         x2 = self.bn2(x2)
82 |         x2 = self.relu(x2)
83 |
84 |         x3 = self.conv3(x)
85 |         x3 = self.bn3(x3)
86 |         x3 = self.relu(x3)
87 |
88 |         x4 = self.conv4(x)
89 |         x4 = self.bn4(x4)
90 |         x4 = self.relu(x4)
91 |
92 |         x5 = self.adapool(x)
93 |         x5 = self.conv5(x5)
94 |         x5 = self.bn5(x5)
95 |         x5 = self.relu(x5)
96 |         x5 = F.interpolate(x5, size = tuple(x4.shape[-2:]), mode='bilinear')
97 |
98 |         x = torch.cat((x1,x2,x3,x4,x5), dim = 1) #channels first
99 |         x = self.convf(x)
100 |         x = self.bnf(x)
101 |         x = self.relu(x)
102 |
103 |         return x
104 |
--------------------------------------------------------------------------------
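The four convolution branches in `ASSP` can be concatenated along the channel axis only because each one keeps the spatial size of its input. A small sketch of why, using the output-size formula given in the PyTorch `Conv2d` documentation; the helper name `conv_out_size` and the 32-pixel feature map are assumptions made for this illustration, not part of the repository.

```python
def conv_out_size(size, kernel, stride=1, padding=0, dilation=1):
    # Output length of a convolution along one spatial dimension.
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

# (kernel_size, padding, dilation) of conv1..conv4 in ASSP
branches = [(1, 0, 1), (3, 6, 6), (3, 12, 12), (3, 18, 18)]

feature_size = 32  # assumed feature-map height/width
sizes = [conv_out_size(feature_size, k, padding=p, dilation=d)
         for k, p, d in branches]
print(sizes)  # [32, 32, 32, 32]
```

For a 3x3 kernel, setting `padding` equal to `dilation` cancels the `dilation * (kernel - 1)` term exactly, which is the pattern the branches follow with rates 6, 12 and 18.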
/models/deeplabv3.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .assp import ASSP
6 | from .resnet_50 import ResNet_50
7 |
8 | class DeepLabv3(nn.Module):
9 |
10 |     def __init__(self, nc):
11 |
12 |         super(DeepLabv3, self).__init__()
13 |
14 |         self.nc = nc
15 |
16 |         self.resnet = ResNet_50()
17 |
18 |         self.assp = ASSP(in_channels = 1024)
19 |
20 |         self.conv = nn.Conv2d(in_channels = 256, out_channels = self.nc,
21 |                               kernel_size = 1, stride=1, padding=0)
22 |
23 |     def forward(self,x):
24 |         _, _, h, w = x.shape
25 |         x = self.resnet(x)
26 |         x = self.assp(x)
27 |         x = self.conv(x)
28 |         x = F.interpolate(x, size=(h, w), mode='bilinear') #scale_factor = 16, mode='bilinear')
29 |         return x
30 |
--------------------------------------------------------------------------------
/models/resnet_50.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torchvision import models
4 |
5 | class ResNet_50 (nn.Module):
6 |     def __init__(self, in_channels = 3, conv1_out = 64):
7 |         super(ResNet_50,self).__init__()
8 |
9 |         self.resnet_50 = models.resnet50(pretrained = True)
10 |
11 |         self.relu = nn.ReLU(inplace=True)
12 |
13 |     def forward(self,x):
14 |         x = self.relu(self.resnet_50.bn1(self.resnet_50.conv1(x)))
15 |         x = self.resnet_50.maxpool(x)
16 |         x = self.resnet_50.layer1(x)
17 |         x = self.resnet_50.layer2(x)
18 |         x = self.resnet_50.layer3(x)
19 |
20 |         return x
21 |
--------------------------------------------------------------------------------
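Because `ResNet_50.forward` stops after `layer3`, the backbone downsamples its input by a factor of 16, which is what the commented-out `scale_factor = 16` in `deeplabv3.py` refers to. The sketch below traces the spatial size through the strided stages. The kernel, stride and padding values are assumptions based on the standard torchvision ResNet-50 layout, and the helper names are hypothetical.

```python
def conv_out_size(size, kernel, stride, padding):
    # Output length of a (non-dilated) convolution or pooling layer.
    return (size + 2 * padding - (kernel - 1) - 1) // stride + 1

# (name, kernel, stride, padding) of the layer that sets each stage's stride.
# Assumed from the standard torchvision ResNet-50 definition.
STAGES = [
    ("conv1", 7, 2, 3),
    ("maxpool", 3, 2, 1),
    ("layer1", 3, 1, 1),
    ("layer2", 3, 2, 1),
    ("layer3", 3, 2, 1),
]

def backbone_out_size(size):
    for _name, kernel, stride, padding in STAGES:
        size = conv_out_size(size, kernel, stride, padding)
    return size

print(backbone_out_size(512))   # 32
print(backbone_out_size(1024))  # 64
```

Four stride-2 stages give an overall stride of 2^4 = 16, so a 512x512 input reaches `ASSP` as a 32x32 feature map before the final bilinear upsampling restores the input size.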
/results/CityScapes/README.md:
--------------------------------------------------------------------------------
1 | ## Results while training on the CityScapes Dataset
2 |
3 | Note: this does not contain all the results produced during training on the dataset, only a few
4 | samples that show how the output changes as training proceeds.
5 |
6 | ## After 50 iterations with `batch_size=2`
7 |
8 | Input:
9 | 
10 |
11 | Activations:
12 | 
13 |
14 | ## After 100 iterations with `batch_size=2`
15 |
16 | Input:
17 | 
18 |
19 | Activations:
20 | 
21 |
22 | ## After 150 iterations with `batch_size=2`
23 |
24 | Input:
25 | 
26 |
27 | Activations:
28 | 
29 |
30 | ## After 200 iterations with `batch_size=2`
31 |
32 | Input:
33 | 
34 |
35 | Activations:
36 | 
37 |
38 | ## After 250 iterations with `batch_size=2`
39 |
40 | Input:
41 | 
42 |
43 | Activations:
44 | 
45 |
46 | ## After 1000 iterations with `batch_size=2`
47 |
48 | Input:
49 | 
50 |
51 | Activations:
52 | 
53 |
54 | ## After 1100 iterations with `batch_size=2`
55 |
56 | Input:
57 | 
58 |
59 | Activations:
60 | 
61 |
62 |
63 |
64 |
65 |
66 |
67 |
--------------------------------------------------------------------------------
/results/pascal voc 2012/README.md:
--------------------------------------------------------------------------------
1 | ## Results while training on the Pascal VOC 2012 Dataset
2 |
3 | Note: this does not contain all the results produced during training on the dataset, only a few samples that show how the output changes as training proceeds.
4 |
5 | ## After 10 epochs with `batch_size=16`
6 |
7 | Input:
8 |
9 | 
10 |
11 | Activation:
12 |
13 | 
14 |
15 | ## After 20 epochs with `batch_size=16`
16 |
17 | Input:
18 |
19 | 
20 |
21 | Activation:
22 |
23 | 
24 |
25 |
--------------------------------------------------------------------------------
/results/pascal voc 2012/epoch_10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AvivSham/DeepLabv3/c718ec3f8190ca2fc45a52a121e9009ca8284e2f/results/pascal voc 2012/epoch_10.png
--------------------------------------------------------------------------------
/results/pascal voc 2012/epoch_10_seg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AvivSham/DeepLabv3/c718ec3f8190ca2fc45a52a121e9009ca8284e2f/results/pascal voc 2012/epoch_10_seg.png
--------------------------------------------------------------------------------
/results/pascal voc 2012/epoch_20.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AvivSham/DeepLabv3/c718ec3f8190ca2fc45a52a121e9009ca8284e2f/results/pascal voc 2012/epoch_20.png
--------------------------------------------------------------------------------
/results/pascal voc 2012/epoch_20_seg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AvivSham/DeepLabv3/c718ec3f8190ca2fc45a52a121e9009ca8284e2f/results/pascal voc 2012/epoch_20_seg.png
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 |
5 | import cv2
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | import matplotlib.pyplot as plt
10 | import matplotlib.gridspec as gridspec
11 | from PIL import Image
12 | from tqdm import tqdm
13 |
14 | from utils import *
15 | from models.deeplabv3 import DeepLabv3
16 |
17 | def predict(FLAGS):
18 |     # Check that a checkpoint and an input image were passed
19 |     if FLAGS.m is None or not FLAGS.m.endswith('.pth'):
20 |         raise RuntimeError('Unknown file passed. Must end with .pth')
21 |     if FLAGS.image_path is None or not os.path.exists(FLAGS.image_path):
22 |         raise RuntimeError('An image file path must be passed')
23 |
24 |     h = FLAGS.resize_height
25 |     w = FLAGS.resize_width
26 |
27 |     print ('[INFO]Loading Checkpoint...')
28 |     checkpoint = torch.load(FLAGS.m, map_location='cpu')
29 |     print ('[INFO]Checkpoint Loaded')
30 |
31 |     # The output is decoded with the Cityscapes colour map below
32 |     deeplabv3 = DeepLabv3(FLAGS.num_classes)
33 |     deeplabv3.load_state_dict(checkpoint['model_state_dict'])
34 |     print ('[INFO]Initiated model with pretrained weights.')
35 |
36 |     tmg_ = np.array(Image.open(FLAGS.image_path))
37 |     # interpolation must be passed by keyword; the third positional argument is dst
38 |     tmg_ = cv2.resize(tmg_, (w, h), interpolation=cv2.INTER_NEAREST)
39 |     tmg = torch.tensor(tmg_).unsqueeze(0).float()
40 |     # N x H x W x C -> N x C x H x W
41 |     tmg = tmg.transpose(2, 3).transpose(1, 2)
42 |
43 |     print ('[INFO]Starting inference...')
44 |     deeplabv3.eval()
45 |     s = time.time()
46 |     with torch.no_grad():
47 |         out1 = deeplabv3(tmg.float()).squeeze(0)
48 |     o = time.time()
49 |     deeplabv3.train()
50 |     print ('[INFO]Inference complete!')
51 |     print ('[INFO]Time taken: ', o - s)
52 |
53 |     # Per-class score maps, shape C x H x W
54 |     out2 = out1.cpu().numpy()
55 |
56 |     # Index of the highest-scoring class at every pixel, shape H x W
57 |     b_ = out1.argmax(dim=0).cpu().numpy()
58 |
59 |     b = decode_segmap_cscapes(b_)
60 |     print ('[INFO]Got segmented results!')
61 |
62 |     plt.title('Input Image')
63 |     plt.axis('off')
64 |     plt.imshow(tmg_)
65 |     plt.show()
66 |
67 |     plt.title('Output Image')
68 |     plt.axis('off')
69 |     # decode_segmap_cscapes returns BGR (for cv2.imwrite); matplotlib expects RGB
70 |     plt.imshow(b[:, :, ::-1])
71 |     plt.show()
72 |
73 |     plt.figure(figsize=(10, 10))
74 |     gs = gridspec.GridSpec(9, 4)
75 |     gs.update(wspace=0.025, hspace=0.005)
76 |
77 |     # Show one score map per class, limited to the 9 x 4 grid
78 |     for ii in range(min(out2.shape[0], 36)):
79 |         plt.subplot(gs[ii])
80 |         plt.axis('off')
81 |         plt.imshow(out2[ii, :, :])
82 |     plt.show()
83 |
84 |     if FLAGS.save:
85 |         cv2.imwrite('seg.png', b)
86 |         print ('[INFO]Segmented image saved successfully!')
87 |
88 |     print ('[INFO] Prediction completed successfully!')
89 |
--------------------------------------------------------------------------------
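In `predict`, the network output has one score map per class, and the label map is obtained by taking, at every pixel, the index of the class with the highest score. The following is a dependency-free sketch of that step on nested lists; the function name `argmax_over_classes` and the toy scores are made up for the illustration, and the real code performs the same reduction on a tensor.

```python
def argmax_over_classes(scores):
    # scores[c][y][x] is the score of class c at pixel (y, x).
    num_classes = len(scores)
    height = len(scores[0])
    width = len(scores[0][0])
    return [
        [max(range(num_classes), key=lambda c: scores[c][y][x])
         for x in range(width)]
        for y in range(height)
    ]

# Two classes on a 1 x 2 image (toy values).
toy_scores = [
    [[0.1, 0.9]],  # class 0
    [[0.8, 0.2]],  # class 1
]
label_map = argmax_over_classes(toy_scores)
print(label_map)  # [[1, 0]]
```

The resulting integer map is what `decode_segmap_cscapes` turns into a colour image by looking up each label in its palette.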
/train.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import sys
4 |
5 | import torch
6 | import torch.nn as nn
7 | from tqdm import tqdm
8 |
9 | from utils import *
10 | from models.deeplabv3 import DeepLabv3
11 |
12 | def train(FLAGS):
13 |
14 |     # Defining the hyperparameters
15 |     device = FLAGS.cuda
16 |     batch_size = FLAGS.batch_size
17 |     epochs = FLAGS.epochs
18 |     lr = FLAGS.learning_rate
19 |     print_every = FLAGS.print_every
20 |     eval_every = FLAGS.eval_every
21 |     save_every = FLAGS.save_every
22 |     nc = FLAGS.num_classes
23 |     wd = FLAGS.weight_decay
24 |
25 |     ip = FLAGS.input_path_train
26 |     lp = FLAGS.label_path_train
27 |
28 |     ipv = FLAGS.input_path_val
29 |     lpv = FLAGS.label_path_val
30 |
31 |     H = FLAGS.resize_height
32 |     W = FLAGS.resize_width
33 |
34 |     dtype = FLAGS.dtype
35 |     sched = FLAGS.scheduler
36 |
37 |     if dtype == 'cityscapes':
38 |         train_samples = len(glob.glob(ip + '/**/*.png', recursive=True))
39 |         eval_samples = len(glob.glob(lp + '/**/*.png', recursive=True))
40 |     elif dtype == 'pascal':
41 |         train_samples = len(os.listdir(lp))
42 |         eval_samples = len(os.listdir(lp))
43 |
44 |     print ('[INFO]Defined all the hyperparameters successfully!')
45 |
46 |     # Get the class weights
47 |     #print ('[INFO]Starting to define the class weights...')
48 |     #pipe = loader(ip, lp, batch_size='all')
49 |     #class_weights = get_class_weights(pipe, nc)
50 |     #print ('[INFO]Fetched all class weights successfully!')
51 |
52 |     # Get an instance of the model
53 |     model = DeepLabv3(nc)
54 |     print ('[INFO]Model Instantiated!')
55 |
56 |     # Move the model to cuda if available
57 |     model.to(device)
58 |
59 |     # Define the criterion and the optimizer
60 |     #criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor(class_weights).to(device))
61 |     criterion = nn.CrossEntropyLoss()
62 |     optimizer = torch.optim.Adam(model.parameters(),
63 |                                  lr=lr,
64 |                                  weight_decay=wd)
65 |     print ('[INFO]Defined the loss function and the optimizer')
66 |
67 |     # Training Loop starts
68 |     print ('[INFO]Starting Training...')
69 |     print ()
70 |
71 |     if dtype == 'cityscapes':
72 |         pipe = loader_cscapes(ip, lp, batch_size, h = H, w = W)
73 |     elif dtype == 'pascal':
74 |         pipe = loader(ip, lp, batch_size, h = H, w = W)
75 |     #eval_pipe = loader(ipv, lpv, batch_size)
76 |
77 |     show_every = 250
78 |
79 |     train_losses = []
80 |     eval_losses = []
81 |
82 |     bc_train = train_samples // batch_size
83 |     bc_eval = eval_samples // batch_size
84 |
85 |     if sched:
86 |         # Polynomial ("poly") decay of the learning rate over the epochs
87 |         scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: (1 - (epoch / epochs)) ** 0.9)
88 |
89 |     for e in range(1, epochs+1):
90 |
91 |         train_loss = 0
92 |         print ('-'*15,'Epoch %d' % e, '-'*15)
93 |
94 |         model.train()
95 |
96 |         for ii in tqdm(range(bc_train)):
97 |             X_batch, mask_batch = next(pipe)
98 |
99 |             X_batch, mask_batch = X_batch.to(device), mask_batch.to(device)
100 |
101 |             optimizer.zero_grad()
102 |
103 |             out = model(X_batch.float())
104 |
105 |             loss = criterion(out, mask_batch.long())
106 |             loss.backward()
107 |             optimizer.step()
108 |
109 |             train_loss += loss.item()
110 |
111 |             if ii % show_every == 0:
112 |                 # The visualisation helper called here originally (show_cscpaes)
113 |                 # is not defined in utils.py, so only the checkpoint is written.
114 |                 checkpoint = {
115 |                     'epochs' : e,
116 |                     'model_state_dict' : model.state_dict(),
117 |                     'opt_state_dict' : optimizer.state_dict()
118 |                 }
119 |                 torch.save(checkpoint, './ckpt-dlabv3-{}-{:2f}.pth'.format(e, train_loss))
120 |                 print ('Model saved!')
121 |
122 |         print ()
123 |         train_losses.append(train_loss)
124 |
125 |         if sched:
126 |             # Step once per epoch, after the optimizer updates of that epoch
127 |             scheduler.step()
128 |
129 |         if e % print_every == 0:
130 |             print ('Epoch {}/{}...'.format(e, epochs),
131 |                    'Loss {:6f}'.format(train_loss))
132 |
133 |         if e % save_every == 0:
134 |             # The original show_pascal(...) call relied on names (training_path,
135 |             # all_tests) that are not defined in this module, so it is omitted.
136 |             # The key matches the one test.py loads ('model_state_dict').
137 |             checkpoint = {
138 |                 'epochs' : e,
139 |                 'model_state_dict' : model.state_dict()
140 |             }
141 |             torch.save(checkpoint, './ckpt-dlabv3-epoch-{}-{:2f}.pth'.format(e, train_loss))
142 |             print ('Model saved!')
143 |
144 |     print ('[INFO]Training Process complete!')
145 |
--------------------------------------------------------------------------------
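When `--scheduler` is enabled, `train.py` scales the base learning rate by the factor `(1 - epoch / epochs) ** 0.9`. The sketch below evaluates that factor on its own, without PyTorch, to show how the learning rate decays. The function name `poly_factor` and the 100-epoch horizon with a base rate of `1e-3` (the defaults in `init.py`) are chosen only for illustration.

```python
def poly_factor(epoch, total_epochs, power=0.9):
    # Multiplier applied to the base learning rate at a given epoch.
    return (1 - epoch / total_epochs) ** power

base_lr = 1e-3
total_epochs = 100

for epoch in (0, 25, 50, 75, 99, 100):
    factor = poly_factor(epoch, total_epochs)
    print(epoch, round(factor, 4), base_lr * factor)
```

The factor starts at 1 and reaches exactly 0 when `epoch == total_epochs`, so stepping the scheduler before the first optimizer update of each epoch would run the last epoch with a learning rate of zero.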
/utils.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 |
4 | import numpy as np
5 | import cv2
6 | import matplotlib.pyplot as plt
7 | from PIL import Image
8 | import torch
9 |
10 | def create_class_mask(img, color_map, is_normalized_img=True, is_normalized_map=False, show_masks=False):
11 |     """
12 |     Function to create C matrices from the segmented image, where each of the C matrices is for one class
13 |     with all ones at the pixel positions where that class is present
14 |
15 |     img = The segmented image
16 |
17 |     color_map = A list with tuples that contains all the RGB values for each color that represents
18 |                 some class in that image
19 |
20 |     is_normalized_img = Boolean - Whether the image is normalized or not
21 |                         If normalized, then the image is multiplied with 255
22 |
23 |     is_normalized_map = Boolean - Represents whether the color map is normalized or not, if so
24 |                         then the image is divided by 255 to match it
25 |
26 |     show_masks = Whether to show the created masks or not (currently unused)
27 |     """
28 |
29 |     if is_normalized_img and (not is_normalized_map):
30 |         # Not done in place, so the caller's array is left untouched
31 |         img = img * 255
32 |
33 |     if is_normalized_map and (not is_normalized_img):
34 |         img = img / 255
35 |
36 |     mask = []
37 |     hw_tuple = img.shape[:-1]
38 |     for color in color_map:
39 |         color_img = []
40 |         for idx in range(3):
41 |             color_img.append(np.ones(hw_tuple) * color[idx])
42 |
43 |         color_img = np.array(color_img, dtype=np.uint8).transpose(1, 2, 0)
44 |
45 |         mask.append(np.uint8((color_img == img).sum(axis = -1) == 3))
46 |
47 |     return np.array(mask)
48 |
49 |
50 | # Cityscapes dataset Loader
51 |
52 | def loader_cscapes(input_path, segmented_path, batch_size, h=1024, w=2048, limited=False):
53 |     # Sort images and labels by the numeric ids in the file name so that both lists line up
54 |     sort_key = lambda x : int(x.split('/')[-1].split('_')[1] + x.split('/')[-1].split('_')[2])
55 |
56 |     filenames_t = sorted(glob.glob(input_path + '/**/*.png', recursive=True), key=sort_key)
57 |     total_files_t = len(filenames_t)
58 |
59 |     filenames_s = sorted(glob.glob(segmented_path + '/**/*labelIds.png', recursive=True), key=sort_key)
60 |     total_files_s = len(filenames_s)
61 |
62 |     assert(total_files_t == total_files_s)
63 |
64 |     if str(batch_size).lower() == 'all':
65 |         batch_size = total_files_s
66 |
67 |     idx = 1 if not limited else total_files_s // batch_size + 1
68 |     while(idx):
69 |
70 |         # Choosing random indexes of images and labels
71 |         batch_idxs = np.random.randint(0, total_files_s, batch_size)
72 |
73 |         inputs = []
74 |         labels = []
75 |
76 |         for jj in batch_idxs:
77 |             # Reading the input photo
78 |             img = np.array(Image.open(filenames_t[jj]))
79 |             # Resizing using nearest neighbor method; cv2 expects (width, height)
80 |             img = cv2.resize(img, (w, h), interpolation=cv2.INTER_NEAREST)
81 |             inputs.append(img)
82 |
83 |             # Reading semantic image
84 |             img = Image.open(filenames_s[jj])
85 |             img = np.array(img)
86 |             # Resizing using nearest neighbor method so that label ids are not blended
87 |             img = cv2.resize(img, (w, h), interpolation=cv2.INTER_NEAREST)
88 |             labels.append(img)
89 |
90 |         inputs = np.stack(inputs, axis=2)
91 |         # Changing image format to N x C x H x W
92 |         inputs = torch.tensor(inputs).transpose(0, 2).transpose(1, 3)
93 |
94 |         labels = torch.tensor(np.stack(labels))
95 |
96 |         if limited:
97 |             idx -= 1
98 |
99 |         yield inputs, labels
100 |
101 | def loader(training_path, segmented_path, batch_size, h=512, w=512):
102 |     """
103 |     The Loader to generate inputs and labels from the Image and Segmented Directory
104 |
105 |     Arguments:
106 |
107 |     training_path - str - Path to the directory that contains the training images
108 |
109 |     segmented_path - str - Path to the directory that contains the segmented images
110 |
111 |     batch_size - int - the batch size
112 |
113 |     yields inputs and labels of the batch size
114 |     """
115 |
116 |     # os.listdir returns names in arbitrary order; sort so images and labels pair up
117 |     filenames_t = sorted(os.listdir(training_path))
118 |     total_files_t = len(filenames_t)
119 |
120 |     filenames_s = sorted(os.listdir(segmented_path))
121 |     total_files_s = len(filenames_s)
122 |
123 |     assert(total_files_t == total_files_s)
124 |
125 |     if str(batch_size).lower() == 'all':
126 |         batch_size = total_files_s
127 |
128 |     while(1):
129 |         batch_idxs = np.random.randint(0, total_files_s, batch_size)
130 |
131 |         inputs = []
132 |         labels = []
133 |
134 |         for jj in batch_idxs:
135 |             img = plt.imread(os.path.join(training_path, filenames_t[jj]))
136 |             # cv2 expects the target size as (width, height)
137 |             img = cv2.resize(img, (w, h), interpolation=cv2.INTER_NEAREST)
138 |             inputs.append(img)
139 |
140 |             img = Image.open(os.path.join(segmented_path, filenames_s[jj]))
141 |             img = np.array(img)
142 |             img = cv2.resize(img, (w, h), interpolation=cv2.INTER_NEAREST)
143 |             labels.append(img)
144 |
145 |         inputs = np.stack(inputs, axis=2)
146 |         # Changing image format to N x C x H x W
147 |         inputs = torch.tensor(inputs).transpose(0, 2).transpose(1, 3)
148 |
149 |         labels = torch.tensor(np.stack(labels))
150 |
151 |         yield inputs, labels
152 |
153 |
154 | def decode_segmap_camvid(image):
155 |     Sky = [128, 128, 128]
156 |     Building = [128, 0, 0]
157 |     Pole = [192, 192, 128]
158 |     Road_marking = [255, 69, 0]
159 |     Road = [128, 64, 128]
160 |     Pavement = [60, 40, 222]
161 |     Tree = [128, 128, 0]
162 |     SignSymbol = [192, 128, 128]
163 |     Fence = [64, 64, 128]
164 |     Car = [64, 0, 128]
165 |     Pedestrian = [64, 64, 0]
166 |     Bicyclist = [0, 128, 192]
167 |
168 |     label_colors = np.array([Sky, Building, Pole, Road_marking, Road,
169 |                              Pavement, Tree, SignSymbol, Fence, Car,
170 |                              Pedestrian, Bicyclist]).astype(np.uint8)
171 |
172 |     r = np.zeros_like(image).astype(np.uint8)
173 |     g = np.zeros_like(image).astype(np.uint8)
174 |     b = np.zeros_like(image).astype(np.uint8)
175 |
176 |     for label in range(len(label_colors)):
177 |         r[image == label] = label_colors[label, 0]
178 |         g[image == label] = label_colors[label, 1]
179 |         b[image == label] = label_colors[label, 2]
180 |
181 |     rgb = np.zeros((image.shape[0], image.shape[1], 3)).astype(np.uint8)
182 |     rgb[:, :, 0] = r
183 |     rgb[:, :, 1] = g
184 |     rgb[:, :, 2] = b
185 |
186 |     return rgb
187 |
188 | def decode_segmap_cscapes(image, nc=34):
189 |
190 |     label_colours = np.array([(0, 0, 0), # 0=background
191 |                               (0, 0, 0), # 1=ego vehicle
192 |                               (0, 0, 0), # 2=rectification border
193 |                               (0, 0, 0), # 3=out of roi
194 |                               (0, 0, 0), # 4=static
195 |                               # 5=dynamic, 6=ground, 7=road, 8=sidewalk, 9=parking
196 |                               (111, 74, 0), ( 81, 0, 81), (128, 64,128), (244, 35,232), (250,170,160),
197 |                               # 10=rail track, 11=building, 12=wall, 13=fence, 14=guard rail
198 |                               (230,150,140), ( 70, 70, 70), (102,102,156), (190,153,153), (180,165,180),
199 |                               # 15=bridge, 16=tunnel, 17=pole, 18=pole group, 19=traffic light
200 |                               (150,100,100), (150,120, 90), (153,153,153), (153,153,153), (250,170, 30),
201 |                               # 20=traffic sign, 21=vegetation, 22=terrain, 23=sky, 24=person
202 |                               (220,220, 0), (107,142, 35), (152,251,152), ( 70,130,180), (220, 20, 60),
203 |                               # 25=rider, 26=car, 27=truck, 28=bus, 29=caravan,
204 |                               (255, 0, 0), ( 0, 0,142), ( 0, 0, 70), ( 0, 60,100), ( 0, 0, 90),
205 |                               # 30=trailer, 31=train, 32=motorcycle, 33=bicycle, 34=license plate,
206 |                               ( 0, 0,110), ( 0, 80,100), ( 0, 0,230), (119, 11, 32), ( 0, 0,142),
207 |                               ])
208 |
209 |     r = np.zeros_like(image).astype(np.uint8)
210 |     g = np.zeros_like(image).astype(np.uint8)
211 |     b = np.zeros_like(image).astype(np.uint8)
212 |
213 |     for l in range(0, nc):
214 |         r[image == l] = label_colours[l, 0]
215 |         g[image == l] = label_colours[l, 1]
216 |         b[image == l] = label_colours[l, 2]
217 |
218 |     # Channels are written in BGR order, which is what cv2.imwrite expects
219 |     rgb = np.zeros((image.shape[0], image.shape[1], 3)).astype(np.uint8)
220 |     rgb[:, :, 0] = b
221 |     rgb[:, :, 1] = g
222 |     rgb[:, :, 2] = r
223 |     return rgb
224 |
225 | def show_images(images, in_row=True):
226 |     '''
227 |     Helper function to show a list of (title, image) pairs
228 |     '''
229 |     total_images = len(images)
230 |
231 |     rc_tuple = (1, total_images)
232 |     if not in_row:
233 |         rc_tuple = (total_images, 1)
234 |
235 |     #figure = plt.figure(figsize=(20, 10))
236 |     for ii in range(len(images)):
237 |         plt.subplot(*rc_tuple, ii+1)
238 |         plt.title(images[ii][0])
239 |         plt.axis('off')
240 |         plt.imshow(images[ii][1])
241 |     plt.show()
242 |
243 | def get_class_weights(loader, num_classes, c=1.02):
244 |     '''
245 |     This function returns the class weights for each class
246 |
247 |     Arguments:
248 |     - loader : The generator object which returns all the labels at one iteration
249 |                Do Note: That this function expects all the labels to be returned in
250 |                one iteration
251 |
252 |     - num_classes : The number of classes
253 |
254 |     Return:
255 |     - class_weights : An array equal in length to the number of classes
256 |                       containing the class weights for each class
257 |     '''
258 |
259 |     _, labels = next(loader)
260 |     # The loaders yield torch tensors; np.bincount needs a flat integer array
261 |     all_labels = labels.flatten().numpy()
262 |     each_class = np.bincount(all_labels, minlength=num_classes)
263 |     propensity_score = each_class / len(all_labels)
264 |     class_weights = 1 / (np.log(c + propensity_score))
265 |     return class_weights
266 |
--------------------------------------------------------------------------------
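`get_class_weights` weights each class by `1 / ln(c + p)`, where `p` is the fraction of pixels that carry that class and `c` defaults to 1.02. The following is a small pure-Python sketch of the same formula, so it can be run without NumPy or a dataset; the function name `class_weights` and the toy label list are assumptions for the illustration.

```python
import math
from collections import Counter

def class_weights(labels, num_classes, c=1.02):
    # p_k: fraction of pixels labelled k; weight_k = 1 / ln(c + p_k)
    counts = Counter(labels)
    total = len(labels)
    return [1.0 / math.log(c + counts.get(k, 0) / total)
            for k in range(num_classes)]

# Toy "image": 8 pixels of class 0, 2 pixels of class 1, none of class 2.
toy_labels = [0] * 8 + [1] * 2
weights = class_weights(toy_labels, num_classes=3)
for k, weight in enumerate(weights):
    print(k, round(weight, 3))
```

Rare classes get larger weights than frequent ones, and because `c` is greater than 1 the logarithm stays positive, so a class that never appears receives the largest finite weight, `1 / ln(c)`, instead of causing a division by zero.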