├── .gitignore ├── DeepLabv3.ipynb ├── LICENSE ├── README.md ├── datasets └── dload.sh ├── init.py ├── models ├── __pycache__ │ ├── assp.cpython-36.pyc │ ├── deeplabv3.cpython-36.pyc │ └── resnet_50.cpython-36.pyc ├── assp.py ├── deeplabv3.py └── resnet_50.py ├── results ├── CityScapes │ └── README.md └── pascal voc 2012 │ ├── README.md │ ├── epoch_10.png │ ├── epoch_10_seg.png │ ├── epoch_20.png │ └── epoch_20_seg.png ├── test.py ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.txt 2 | *.html 3 | *.swp 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # pyenv 80 | .python-version 81 | 82 | # celery beat schedule file 83 | celerybeat-schedule 84 | 85 | # SageMath parsed files 86 | *.sage.py 87 | 88 | # Environments 89 | .env 90 | .venv 91 | env/ 92 | venv/ 93 | ENV/ 94 | env.bak/ 95 | venv.bak/ 96 | 97 | # Spyder project settings 98 | .spyderproject 99 | .spyproject 100 | 101 | # Rope project settings 102 | .ropeproject 103 | 104 | # mkdocs documentation 105 | /site 106 | 107 | # mypy 108 | .mypy_cache/ 109 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Aviv Shamsian 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DeepLabv3 2 | 3 | In this repository we reproduce the DeepLabv3 paper, which can be found here: [Rethinking Atrous Convolution for Semantic Image Segmentation](https://arxiv.org/pdf/1706.05587.pdf) 4 | DeepLabv3 uses ResNet50 or ResNet101 as its feature-extraction backbone, so this repository contains the code of the ResNet50 backbone and will also contain the ResNet101 backbone. 5 | We will also release pretrained models (the Colab notebook is already available, see below). 6 | 7 | ## How to use 8 | 9 | 0. This repository comes with a handy notebook that you can use with Colab.
10 | You can find a link to the notebook here: [ 11 | DeepLabv3](https://github.com/AvivSham/DeepLabv3/blob/master/DeepLabv3.ipynb)
12 | Open it in Colab: [Open in Colab](https://colab.research.google.com/github/AvivSham/DeepLabv3/blob/master/DeepLabv3.ipynb) 13 | 14 | --- 15 | 16 | 17 | 0. Clone the repository and cd into it 18 | ``` 19 | git clone https://github.com/AvivSham/DeepLabv3.git 20 | cd DeepLabv3/ 21 | ``` 22 | 23 | 1. Use this command to train the model (`-nc` is the number of classes in your dataset) 24 | ``` 25 | python3 init.py --mode train -iptr path/to/train/input/set/ -lptr /path/to/label/set/ --cuda False -nc <num_classes> 26 | ``` 27 | 28 | 2. Use this command to test the model (`-nc` must match the value used for training) 29 | ``` 30 | python3 init.py --mode test -m /path/to/model.pth -i /path/to/image.png -nc <num_classes> 31 | ``` 32 | 33 | 3. Use `--help` to list all available options 34 | ``` 35 | python3 init.py --help 36 | ``` 37 | 38 | --- 39 | 40 | 41 | 0. If you want to download the Cityscapes dataset (requires a cityscapes-dataset.com account) 42 | ``` 43 | sh ./datasets/dload.sh cityscapes <username> <password> 44 | ``` 45 | 46 | 1. If you want to download the PASCAL VOC 2012 dataset 47 | ``` 48 | sh ./datasets/dload.sh pascal 49 | ``` 50 | 51 | ## Results 52 | 53 | ### Pascal VOC 2012 54 | 55 | ### CityScapes 56 | 57 | ## References 58 | 1. [Rethinking Atrous Convolution for Semantic Image Segmentation](https://arxiv.org/pdf/1706.05587.pdf) 59 | 2. [Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation](https://arxiv.org/pdf/1802.02611.pdf) 60 | 61 | ## License 62 | 63 | The code in this repository is free to use and to modify with a proper link back to this repository.
64 | -------------------------------------------------------------------------------- /datasets/dload.sh: -------------------------------------------------------------------------------- 1 | if [ "$1" = "cityscapes" ]; then 2 | if [ "$2" = "" ]; then 3 | echo 'Invalid username / password' 4 | exit 5 | fi 6 | 7 | wget --keep-session-cookies --save-cookies=cookies.txt --post-data 'username='$2'&password='$3'&submit=Login' https://www.cityscapes-dataset.com/login/ 8 | wget --load-cookies cookies.txt --content-disposition https://www.cityscapes-dataset.com/file-handling/?packageID=1 9 | wget --load-cookies cookies.txt --content-disposition https://www.cityscapes-dataset.com/file-handling/?packageID=3 10 | unzip -qq gtFine_trainvaltest.zip 11 | unzip -qq leftImg8bit_trainvaltest.zip 12 | 13 | elif [ "$1" = "pascal" ]; then 14 | wget http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar -O VOCtrainval.tar 15 | tar -xf VOCtrainval.tar 16 | 17 | else 18 | echo "Invalid Argument" 19 | fi 20 | -------------------------------------------------------------------------------- /init.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import argparse 3 | from train import * 4 | from test import * 5 | 6 | if __name__ == '__main__': 7 | parser = argparse.ArgumentParser() 8 | 9 | parser.add_argument('-m', 10 | type=str, 11 | help='The path to the pretrained cscapes model') 12 | 13 | parser.add_argument('-i', '--image-path', 14 | type=str, 15 | help='The path to the image to perform semantic segmentation') 16 | 17 | parser.add_argument('-rh', '--resize-height', 18 | type=int, 19 | default=1024, 20 | help='The height for the resized image') 21 | 22 | parser.add_argument('-rw', '--resize-width', 23 | type=int, 24 | default=2048, 25 | help='The width for the resized image') 26 | 27 | parser.add_argument('-lr', '--learning-rate', 28 | type=float, 29 | default=1e-3, 30 | help='The learning rate') 31 | 32 | 
parser.add_argument('-bs', '--batch-size', 33 | type=int, 34 | default=2, 35 | help='The batch size') 36 | 37 | parser.add_argument('-wd', '--weight-decay', 38 | type=float, 39 | default=1e-4, 40 | help='The weight decay') 41 | 42 | parser.add_argument('-c', '--constant', 43 | type=float, 44 | default=1.02, 45 | help='The constant used for calculating the class weights') 46 | 47 | parser.add_argument('-e', '--epochs', 48 | type=int, 49 | default=100, 50 | help='The number of epochs') 51 | 52 | parser.add_argument('-nc', '--num-classes', 53 | type=int, 54 | required=True, 55 | help='The number of epochs') 56 | 57 | parser.add_argument('-se', '--save-every', 58 | type=int, 59 | default=10, 60 | help='The number of epochs after which to save a model') 61 | 62 | parser.add_argument('-iptr', '--input-path-train', 63 | type=str, 64 | help='The path to the input dataset') 65 | 66 | parser.add_argument('-lptr', '--label-path-train', 67 | type=str, 68 | help='The path to the label dataset') 69 | 70 | parser.add_argument('-ipv', '--input-path-val', 71 | type=str, 72 | help='The path to the input dataset') 73 | 74 | parser.add_argument('-lpv', '--label-path-val', 75 | type=str, 76 | help='The path to the label dataset') 77 | 78 | parser.add_argument('-iptt', '--input-path-test', 79 | type=str, 80 | help='The path to the input dataset') 81 | 82 | parser.add_argument('-lptt', '--label-path-test', 83 | type=str, 84 | help='The path to the label dataset') 85 | 86 | parser.add_argument('-pe', '--print-every', 87 | type=int, 88 | default=1, 89 | help='The number of epochs after which to print the training loss') 90 | 91 | parser.add_argument('-ee', '--eval-every', 92 | type=int, 93 | default=10, 94 | help='The number of epochs after which to print the validation loss') 95 | 96 | parser.add_argument('--cuda', 97 | type=bool, 98 | default=False, 99 | help='Whether to use cuda or not') 100 | 101 | parser.add_argument('--mode', 102 | choices=['train', 'test'], 103 | default='train', 
104 | help='Whether to train or test') 105 | 106 | parser.add_argument('-dt', '--dtype', 107 | choices=['cityscapes', 'pascal'], 108 | default='pascal', 109 | help='specify the dataset you are using') 110 | 111 | parser.add_argument('--scheduler', 112 | type=bool, 113 | default=False, 114 | help='Whether to use scheduler or not') 115 | 116 | parser.add_argument('--save', 117 | type=bool, 118 | default=True, 119 | help='Save the segmented image when predicting') 120 | 121 | FLAGS, unparsed = parser.parse_known_args() 122 | 123 | FLAGS.cuda = torch.device('cuda:0' if torch.cuda.is_available() and FLAGS.cuda \ 124 | else 'cpu') 125 | 126 | print ('[INFO]Arguments read successfully!') 127 | 128 | if FLAGS.mode.lower() == 'train': 129 | print ('[INFO]Train Mode.') 130 | 131 | if FLAGS.iptr == None or FLAGS.ipv == None: 132 | raise ('Error: Kindly provide the path to the dataset') 133 | 134 | train(FLAGS) 135 | 136 | elif FLAGS.mode.lower() == 'test': 137 | print ('[INFO]Predict Mode.') 138 | predict(FLAGS) 139 | else: 140 | raise RuntimeError('Unknown mode passed. 
\n Mode passed should be either \ 141 | of "train" or "test"') 142 | -------------------------------------------------------------------------------- /models/__pycache__/assp.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AvivSham/DeepLabv3/c718ec3f8190ca2fc45a52a121e9009ca8284e2f/models/__pycache__/assp.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/deeplabv3.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AvivSham/DeepLabv3/c718ec3f8190ca2fc45a52a121e9009ca8284e2f/models/__pycache__/deeplabv3.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/resnet_50.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AvivSham/DeepLabv3/c718ec3f8190ca2fc45a52a121e9009ca8284e2f/models/__pycache__/resnet_50.cpython-36.pyc -------------------------------------------------------------------------------- /models/assp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class ASSP(nn.Module): 6 | def __init__(self,in_channels,out_channels = 256): 7 | super(ASSP,self).__init__() 8 | 9 | 10 | self.relu = nn.ReLU(inplace=True) 11 | 12 | self.conv1 = nn.Conv2d(in_channels = in_channels, 13 | out_channels = out_channels, 14 | kernel_size = 1, 15 | padding = 0, 16 | dilation=1, 17 | bias=False) 18 | 19 | self.bn1 = nn.BatchNorm2d(out_channels) 20 | 21 | self.conv2 = nn.Conv2d(in_channels = in_channels, 22 | out_channels = out_channels, 23 | kernel_size = 3, 24 | stride=1, 25 | padding = 6, 26 | dilation = 6, 27 | bias=False) 28 | 29 | self.bn2 = nn.BatchNorm2d(out_channels) 30 | 31 | self.conv3 = 
nn.Conv2d(in_channels = in_channels, 32 | out_channels = out_channels, 33 | kernel_size = 3, 34 | stride=1, 35 | padding = 12, 36 | dilation = 12, 37 | bias=False) 38 | 39 | self.bn3 = nn.BatchNorm2d(out_channels) 40 | 41 | self.conv4 = nn.Conv2d(in_channels = in_channels, 42 | out_channels = out_channels, 43 | kernel_size = 3, 44 | stride=1, 45 | padding = 18, 46 | dilation = 18, 47 | bias=False) 48 | 49 | self.bn4 = nn.BatchNorm2d(out_channels) 50 | 51 | self.conv5 = nn.Conv2d(in_channels = in_channels, 52 | out_channels = out_channels, 53 | kernel_size = 1, 54 | stride=1, 55 | padding = 0, 56 | dilation=1, 57 | bias=False) 58 | 59 | self.bn5 = nn.BatchNorm2d(out_channels) 60 | 61 | self.convf = nn.Conv2d(in_channels = out_channels * 5, 62 | out_channels = out_channels, 63 | kernel_size = 1, 64 | stride=1, 65 | padding = 0, 66 | dilation=1, 67 | bias=False) 68 | 69 | self.bnf = nn.BatchNorm2d(out_channels) 70 | 71 | self.adapool = nn.AdaptiveAvgPool2d(1) 72 | 73 | 74 | def forward(self,x): 75 | 76 | x1 = self.conv1(x) 77 | x1 = self.bn1(x1) 78 | x1 = self.relu(x1) 79 | 80 | x2 = self.conv2(x) 81 | x2 = self.bn2(x2) 82 | x2 = self.relu(x2) 83 | 84 | x3 = self.conv3(x) 85 | x3 = self.bn3(x3) 86 | x3 = self.relu(x3) 87 | 88 | x4 = self.conv4(x) 89 | x4 = self.bn4(x4) 90 | x4 = self.relu(x4) 91 | 92 | x5 = self.adapool(x) 93 | x5 = self.conv5(x5) 94 | x5 = self.bn5(x5) 95 | x5 = self.relu(x5) 96 | x5 = F.interpolate(x5, size = tuple(x4.shape[-2:]), mode='bilinear') 97 | 98 | x = torch.cat((x1,x2,x3,x4,x5), dim = 1) #channels first 99 | x = self.convf(x) 100 | x = self.bnf(x) 101 | x = self.relu(x) 102 | 103 | return x 104 | -------------------------------------------------------------------------------- /models/deeplabv3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .assp import ASSP 6 | from .resnet_50 import ResNet_50 7 | 8 | class 
DeepLabv3(nn.Module): 9 | 10 | def __init__(self, nc): 11 | 12 | super(DeepLabv3, self).__init__() 13 | 14 | self.nc = nc 15 | 16 | self.resnet = ResNet_50() 17 | 18 | self.assp = ASSP(in_channels = 1024) 19 | 20 | self.conv = nn.Conv2d(in_channels = 256, out_channels = self.nc, 21 | kernel_size = 1, stride=1, padding=0) 22 | 23 | def forward(self,x): 24 | _, _, h, w = x.shape 25 | x = self.resnet(x) 26 | x = self.assp(x) 27 | x = self.conv(x) 28 | x = F.interpolate(x, size=(h, w), mode='bilinear') #scale_factor = 16, mode='bilinear') 29 | return x 30 | -------------------------------------------------------------------------------- /models/resnet_50.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision import models 4 | 5 | class ResNet_50 (nn.Module): 6 | def __init__(self, in_channels = 3, conv1_out = 64): 7 | super(ResNet_50,self).__init__() 8 | 9 | self.resnet_50 = models.resnet50(pretrained = True) 10 | 11 | self.relu = nn.ReLU(inplace=True) 12 | 13 | def forward(self,x): 14 | x = self.relu(self.resnet_50.bn1(self.resnet_50.conv1(x))) 15 | x = self.resnet_50.maxpool(x) 16 | x = self.resnet_50.layer1(x) 17 | x = self.resnet_50.layer2(x) 18 | x = self.resnet_50.layer3(x) 19 | 20 | return x 21 | -------------------------------------------------------------------------------- /results/CityScapes/README.md: -------------------------------------------------------------------------------- 1 | ## Results while training on the CityScapes Dataset 2 | 3 | Do note: This doesn't contain all the results produced during training on the dataset. Just some along the way, how it looks as the training 4 | proceeds. 
5 | 6 | ## After 50 iteration with `batch_size=2` 7 | 8 | Input: 9 | ![Input](https://lh3.googleusercontent.com/-8mSdHCZ_Y-0/XHp4n9nXSDI/AAAAAAAAAWY/sN8vbOmBJb07uVghvWOAzlJ7OYqxozdHwCL0BGAYYCw/h269/2019-03-02.png) 10 | 11 | Activations: 12 | ![Activations](https://lh3.googleusercontent.com/-JFbreEQT2Yw/XHp0ZGG5fwI/AAAAAAAAAWI/R3zH5dzHddwTPeOQw3HNfWRE5ij9JU9vACL0BGAYYCw/h578/2019-03-02.png) 13 | 14 | ## After 100 iteration with `batch_size=2` 15 | 16 | Input: 17 | ![Input](https://lh3.googleusercontent.com/-w3VwzCuArKc/XHp4sxUwaWI/AAAAAAAAAWg/lLgT0Vl1GXkOLOCGTXZ4rUNIRX8eSnxUQCL0BGAYYCw/h269/2019-03-02.png) 18 | 19 | Activations: 20 | ![Activations](https://lh3.googleusercontent.com/-8mJ2l1lKCv4/XHp4u1jbNhI/AAAAAAAAAWo/DMnuEfs5Ls8thBWTezkGVipS3zBqolpggCL0BGAYYCw/h578/2019-03-02.png) 21 | 22 | ## After 150 iteration with `batch_size=2` 23 | 24 | Input: 25 | ![Input](https://lh3.googleusercontent.com/-D7oQvDWQBrk/XHp5GVzfDwI/AAAAAAAAAWs/rd_ORpIrh-U90OptEn-pszKH43QrAsLcgCL0BGAYYCw/h269/2019-03-02.png) 26 | 27 | Activations: 28 | ![Activations](https://lh3.googleusercontent.com/-1l4Siwhn8po/XHp5Iatr1PI/AAAAAAAAAWw/N-ab8qAByhgBG7GcmPd6Kz_VyjxPZS90QCL0BGAYYCw/h578/2019-03-02.png) 29 | 30 | ## After 200 iteration with `batch_size=2` 31 | 32 | Input: 33 | ![Input](https://lh3.googleusercontent.com/-oNCQMuN7i70/XHp5XSztZYI/AAAAAAAAAW0/L8zjuUeubncQv9MbcB6__CUNWl9IIEVKwCL0BGAYYCw/h269/2019-03-02.png) 34 | 35 | Activations: 36 | ![Activations](https://lh3.googleusercontent.com/-yRkwEJ77_1U/XHp5Zs0oZOI/AAAAAAAAAW4/XBlO9btzNPs_EuJv8OMPSk1ucRMV3NRFACL0BGAYYCw/h578/2019-03-02.png) 37 | 38 | ## After 250 iteration with `batch_size=2` 39 | 40 | Input: 41 | ![Input](https://lh3.googleusercontent.com/-DtCCEXFwIYI/XHp5cS7Bn4I/AAAAAAAAAW8/QZZ2YwyHLUcr6Wp7kQfwa8HGmrub5sVOQCL0BGAYYCw/h269/2019-03-02.png) 42 | 43 | Activations: 44 | ![Activations](https://lh3.googleusercontent.com/-nQ8DuYoow98/XHp5fZh91II/AAAAAAAAAXA/Ho70JLqmvngmc_gy5MRht9obzNJ5W_1-QCL0BGAYYCw/h578/2019-03-02.png) 45 | 46 | 
## After 1000 iteration with `batch_size=2` 47 | 48 | Input: 49 | ![Input](https://plus.google.com/u/0/photos/albums/pcvhpi9bsgugshktn44ms9i9rm9gnjk34?pid=6663791313794764578&oid=108416174423949606030) 50 | 51 | Activations: 52 | ![Activations](https://lh3.googleusercontent.com/-FSuuz4f-3zM/XHqK7JUn9_I/AAAAAAAAAXQ/7iOsGsuWINIdbubacaHXPhtELT-enJKUACL0BGAYYCw/h578/2019-03-02.png) 53 | 54 | ## After 1100 iteration with `batch_size=2` 55 | 56 | Input: 57 | ![Input](https://lh3.googleusercontent.com/-b-Bsz-X4Zcg/XHqLry3Ce1I/AAAAAAAAAXY/TxggAuMn3r0V-O0pSTLXf7k8YQA6NictgCL0BGAYYCw/h269/2019-03-02.png) 58 | 59 | Activations: 60 | ![Activations](https://lh3.googleusercontent.com/--lQPxg8U0aM/XHqLupFab6I/AAAAAAAAAXc/WWcxbJuijBQIc9e1NBoH6tc_RGTQbtCHQCL0BGAYYCw/h578/2019-03-02.png) 61 | 62 | 63 | 64 | 65 | 66 | 67 | -------------------------------------------------------------------------------- /results/pascal voc 2012/README.md: -------------------------------------------------------------------------------- 1 | ## Results while training on the Pascal VOC 2012 Dataset 2 | 3 | Do note: This doesn't contain all the results produced during training on the dataset. Just some along the way, how it looks as the training proceeds. 
4 | 5 | ## After 10 epochs with `batch_size=16` 6 | 7 | Input: 8 | 9 | ![inp](https://github.com/AvivSham/DeepLabv3/blob/master/results/pascal%20voc%202012/epoch_10.png) 10 | 11 | Activation: 12 | 13 | ![inp](https://github.com/AvivSham/DeepLabv3/blob/master/results/pascal%20voc%202012/epoch_10_seg.png) 14 | 15 | ## After 20 epochs with `batch_size=16` 16 | 17 | Input: 18 | 19 | ![inp](https://github.com/AvivSham/DeepLabv3/blob/master/results/pascal%20voc%202012/epoch_20.png) 20 | 21 | Activation: 22 | 23 | ![inp](https://github.com/AvivSham/DeepLabv3/blob/master/results/pascal%20voc%202012/epoch_20_seg.png) 24 | 25 | -------------------------------------------------------------------------------- /results/pascal voc 2012/epoch_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AvivSham/DeepLabv3/c718ec3f8190ca2fc45a52a121e9009ca8284e2f/results/pascal voc 2012/epoch_10.png -------------------------------------------------------------------------------- /results/pascal voc 2012/epoch_10_seg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AvivSham/DeepLabv3/c718ec3f8190ca2fc45a52a121e9009ca8284e2f/results/pascal voc 2012/epoch_10_seg.png -------------------------------------------------------------------------------- /results/pascal voc 2012/epoch_20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AvivSham/DeepLabv3/c718ec3f8190ca2fc45a52a121e9009ca8284e2f/results/pascal voc 2012/epoch_20.png -------------------------------------------------------------------------------- /results/pascal voc 2012/epoch_20_seg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AvivSham/DeepLabv3/c718ec3f8190ca2fc45a52a121e9009ca8284e2f/results/pascal voc 2012/epoch_20_seg.png 
-------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from utils import * 4 | from models.deeplabv3 import DeepLabv3 5 | import sys 6 | import os 7 | import time 8 | from tqdm import tqdm 9 | from PIL import Image 10 | import matplotlib.pyplot as plt 11 | import matplotlib.gridspec as gridspec 12 | 13 | def predict(FLAGS): 14 | # Check if the pretrained model is available 15 | if not FLAGS.m.endswith('.pth'): 16 | raise RuntimeError('Unknown file passed. Must end with .pth') 17 | if FLAGS.image_path is None or not os.path.exists(FLAGS.image_path): 18 | raise RuntimeError('An image file path must be passed') 19 | 20 | h = FLAGS.resize_height 21 | w = FLAGS.resize_width 22 | 23 | print ('[INFO]Loading Checkpoint...') 24 | checkpoint = torch.load(FLAGS.m, map_location='cpu') 25 | print ('[INFO]Checkpoint Loaded') 26 | 27 | # Assuming the dataset is camvid 28 | deeplabv3 = DeepLabv3(FLAGS.num_classes) 29 | deeplabv3.load_state_dict(checkpoint['model_state_dict']) 30 | print ('[INFO]Initiated model with pretraiend weights.') 31 | 32 | tmg_ = np.array(Image.open(FLAGS.image_path)) 33 | tmg_ = cv2.resize(tmg_, (w, h), cv2.INTER_NEAREST) 34 | tmg = torch.tensor(tmg_).unsqueeze(0).float() 35 | tmg = tmg.transpose(2, 3).transpose(1, 2) 36 | 37 | print ('[INFO]Starting inference...') 38 | deeplabv3.eval() 39 | s = time.time() 40 | out1 = deeplabv3(tmg.float()).squeeze(0) 41 | o = time.time() 42 | deeplabv3.train() 43 | print ('[INFO]Inference complete!') 44 | print ('[INFO]Time taken: ', o - s) 45 | 46 | out2 = out1.squeeze(0).cpu().detach().numpy() 47 | 48 | b_ = out1.data.max(0)[1].cpu().detach().numpy() 49 | 50 | b = decode_segmap_cscapes(b_) 51 | print ('[INFO]Got segmented results!') 52 | 53 | plt.title('Input Image') 54 | plt.axis('off') 55 | plt.imshow(tmg_) 56 | plt.show() 57 | 58 | plt.title('Output 
Image') 59 | plt.axis('off') 60 | plt.imshow(b) 61 | plt.show() 62 | 63 | plt.figure(figsize=(10, 10)) 64 | gs = gridspec.GridSpec(9, 4) 65 | gs.update(wspace=0.025, hspace=0.005) 66 | 67 | label = 0 68 | for ii in range(34): 69 | plt.subplot(gs[ii]) 70 | plt.axis('off') 71 | plt.imshow(out2[label, :, :]) 72 | label += 1 73 | plt.show() 74 | 75 | if FLAGS.save: 76 | cv2.imwrite('seg.png', b) 77 | print ('[INFO]Segmented image saved successfully!') 78 | 79 | print ('[INFO] Prediction complete successfully!') 80 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from utils import * 4 | from models.deeplabv3 import DeepLabv3 5 | import sys 6 | from tqdm import tqdm 7 | 8 | def train(FLAGS): 9 | 10 | # Defining the hyperparameters 11 | device = FLAGS.cuda 12 | batch_size = FLAGS.batch_size 13 | epochs = FLAGS.epochs 14 | lr = FLAGS.learning_rate 15 | print_every = FLAGS.print_every 16 | eval_every = FLAGS.eval_every 17 | save_every = FLAGS.save_every 18 | nc = FLAGS.num_classes 19 | wd = FLAGS.weight_decay 20 | 21 | ip = FLAGS.input_path_train 22 | lp = FLAGS.label_path_train 23 | 24 | ipv = FLAGS.input_path_val 25 | lpv = FLAGS.label_path_val 26 | 27 | H = FLAGS.resize_height 28 | W = FLAGS.resize_width 29 | 30 | dtype = FLAGS.dtype 31 | sched = FLAGS.scheduler 32 | 33 | if FLAGS.dtype == 'cityscapes': 34 | train_samples = len(glob.glob(ip + '/**/*.png', recursive=True)) 35 | eval_samples = len(glob.glob(lp + '/**/*.png', recursive=True)) 36 | elif FLAGS.dtype == 'pascal': 37 | train_samples = len(os.listdir(lp)) 38 | eval_samples = len(os.listdir(lp)) 39 | 40 | print ('[INFO]Defined all the hyperparameters successfully!') 41 | 42 | # Get the class weights 43 | #print ('[INFO]Starting to define the class weights...') 44 | #pipe = loader(ip, lp, batch_size='all') 45 | #class_weights = 
get_class_weights(pipe, nc) 46 | #print ('[INFO]Fetched all class weights successfully!') 47 | 48 | # Get an instance of the model 49 | model = DeepLabv3(nc) 50 | print ('[INFO]Model Instantiated!') 51 | 52 | # Move the model to cuda if available 53 | model.to(device) 54 | 55 | # Define the criterion and the optimizer 56 | #criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor(class_weights).to(device)) 57 | criterion = nn.CrossEntropyLoss() 58 | optimizer = torch.optim.Adam(model.parameters(), 59 | lr=lr, 60 | weight_decay=wd) 61 | print ('[INFO]Defined the loss function and the optimizer') 62 | 63 | # Training Loop starts 64 | print ('[INFO]Staring Training...') 65 | print () 66 | 67 | train_losses = [] 68 | eval_losses = [] 69 | 70 | if dtype == 'cityscapes': 71 | pipe = loader_cscapes(ip, lp, batch_size, h = H, w = W) 72 | elif dtype == 'pascal': 73 | pipe = loader(ip, lp, batch_size, h = H, w = W) 74 | #eval_pipe = loader(ipv, lpv, batch_size) 75 | 76 | show_every = 250 77 | 78 | train_losses = [] 79 | eval_losses = [] 80 | 81 | bc_train = train_samples // batch_size 82 | bc_eval = eval_samples // batch_size 83 | 84 | if sched: 85 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: (1 - (epoch / epochs)) ** 0.9) 86 | 87 | for e in range(1, epochs+1): 88 | 89 | train_loss = 0 90 | print ('-'*15,'Epoch %d' % e, '-'*15) 91 | 92 | if sched: 93 | scheduler.step() 94 | 95 | model.train() 96 | 97 | for ii in tqdm(range(bc_train)): 98 | X_batch, mask_batch = next(pipe) 99 | 100 | X_batch, mask_batch = X_batch.to(device), mask_batch.to(device) 101 | 102 | optimizer.zero_grad() 103 | 104 | out = model(X_batch.float()) 105 | 106 | loss = criterion(out, mask_batch.long()) 107 | loss.backward() 108 | optimizer.step() 109 | 110 | train_loss += loss.item() 111 | 112 | if ii % show_every == 0: 113 | out5 = show_cscpaes(model, H, W) 114 | checkpoint = { 115 | 'epochs' : e, 116 | 'model_state_dict' : model.state_dict(), 117 | 'opt_state_dict' : 
optimizer.state_dict() 118 | } 119 | torch.save(checkpoint, './ckpt-dlabv3-{}-{:2f}.pth'.format(e, train_loss)) 120 | print ('Model saved!') 121 | 122 | print () 123 | train_losses.append(train_loss) 124 | 125 | if (e+1) % print_every == 0: 126 | print ('Epoch {}/{}...'.format(e, epochs), 127 | 'Loss {:6f}'.format(train_loss)) 128 | 129 | if e % save_every == 0: 130 | 131 | show_pascal(model, training_path, all_tests[np.random.randint(0, len(all_tests))]) 132 | checkpoint = { 133 | 'epochs' : e, 134 | 'state_dict' : model.state_dict() 135 | } 136 | torch.save(checkpoint, '/content/ckpt-enet-{}-{:2f}.pth'.format(e, train_loss)) 137 | print ('Model saved!') 138 | 139 | 140 | # show(model, all_tests[np.random.randint(0, len(all_tests))]) 141 | # show_pascal(model, training_path, all_tests[np.random.randint(0, len(all_tests))]) 142 | 143 | print ('[INFO]Training Process complete!') 144 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import matplotlib.pyplot as plt 4 | import os 5 | from PIL import Image 6 | import torch 7 | 8 | def create_class_mask(img, color_map, is_normalized_img=True, is_normalized_map=False, show_masks=False): 9 | """ 10 | Function to create C matrices from the segmented image, where each of the C matrices is for one class 11 | with all ones at the pixel positions where that class is present 12 | 13 | img = The segmented image 14 | 15 | color_map = A list with tuples that contains all the RGB values for each color that represents 16 | some class in that image 17 | 18 | is_normalized_img = Boolean - Whether the image is normalized or not 19 | If normalized, then the image is multiplied with 255 20 | 21 | is_normalized_map = Boolean - Represents whether the color map is normalized or not, if so 22 | then the color map values are multiplied with 255 23 | 24 | show_masks = Wherether to 
show the created masks or not 25 | """ 26 | 27 | if is_normalized_img and (not is_normalized_map): 28 | img *= 255 29 | 30 | if is_normalized_map and (not is_normalized_img): 31 | img = img / 255 32 | 33 | mask = [] 34 | hw_tuple = img.shape[:-1] 35 | for color in color_map: 36 | color_img = [] 37 | for idx in range(3): 38 | color_img.append(np.ones(hw_tuple) * color[idx]) 39 | 40 | color_img = np.array(color_img, dtype=np.uint8).transpose(1, 2, 0) 41 | 42 | mask.append(np.uint8((color_img == img).sum(axis = -1) == 3)) 43 | 44 | return np.array(mask) 45 | 46 | 47 | # Cityscapes dataset Loader 48 | 49 | def loader_cscapes(input_path, segmented_path, batch_size, h=1024, w=2048, limited=False): 50 | filenames_t = sorted(glob.glob(input_path + '/**/*.png', recursive=True), key=lambda x : int(x.split('/')[-1].split('_')[1] + x.split('/')[-1].split('_')[2])) 51 | total_files_t = len(filenames_t) 52 | 53 | filenames_s = sorted(glob.glob(segmented_path + '/**/*labelIds.png', recursive=True), key=lambda x : int(x.split('/')[-1].split('_')[1] + x.split('/')[-1].split('_')[2])) 54 | 55 | total_files_s = len(filenames_s) 56 | 57 | assert(total_files_t == total_files_s) 58 | 59 | batches = np.random.permutation(np.arange(total_files_s)) 60 | idx0 = 0 61 | idx1 = idx0 + batch_size 62 | 63 | if str(batch_size).lower() == 'all': 64 | batch_size = total_files_s 65 | 66 | idx = 1 if not limited else total_files_s // batch_size + 1 67 | while(idx): 68 | 69 | batch = np.arange(idx0, idx1) 70 | 71 | # Choosing random indexes of images and labels 72 | batch_idxs = np.random.randint(0, total_files_s, batch_size) 73 | 74 | inputs = [] 75 | labels = [] 76 | 77 | for jj in batch_idxs: 78 | # Reading normalized photo 79 | img = np.array(Image.open(filenames_t[jj])) 80 | # Resizing using nearest neighbor method 81 | inputs.append(img) 82 | 83 | # Reading semantic image 84 | img = Image.open(filenames_s[jj]) 85 | img = np.array(img) 86 | # Resizing using nearest neighbor method 87 | 
labels.append(img) 88 | 89 | inputs = np.stack(inputs, axis=2) 90 | # Changing image format to C x H x W 91 | inputs = torch.tensor(inputs).transpose(0, 2).transpose(1, 3) 92 | 93 | labels = torch.tensor(labels) 94 | 95 | idx0 = idx1 if idx1 + batch_size < total_files_s else 0 96 | idx1 = idx0 + batch_size 97 | 98 | if limited: 99 | idx -= 1 100 | 101 | yield inputs, labels 102 | 103 | def loader(training_path, segmented_path, batch_size, h=512, w=512): 104 | """ 105 | The Loader to generate inputs and labels from the Image and Segmented Directory 106 | 107 | Arguments: 108 | 109 | training_path - str - Path to the directory that contains the training images 110 | 111 | segmented_path - str - Path to the directory that contains the segmented images 112 | 113 | batch_size - int - the batch size 114 | 115 | yields inputs and labels of the batch size 116 | """ 117 | 118 | filenames_t = os.listdir(training_path) 119 | total_files_t = len(filenames_t) 120 | 121 | filenames_s = os.listdir(segmented_path) 122 | total_files_s = len(filenames_s) 123 | 124 | assert(total_files_t == total_files_s) 125 | 126 | if str(batch_size).lower() == 'all': 127 | batch_size = total_files_s 128 | 129 | idx = 0 130 | while(1): 131 | batch_idxs = np.random.randint(0, total_files_s, batch_size) 132 | 133 | inputs = [] 134 | labels = [] 135 | 136 | for jj in batch_idxs: 137 | img = plt.imread(training_path + filenames_t[jj]) 138 | img = cv2.resize(img, (h, w), cv2.INTER_NEAREST) 139 | inputs.append(img) 140 | 141 | img = Image.open(segmented_path + filenames_s[jj]) 142 | img = np.array(img) 143 | img = cv2.resize(img, (h, w), cv2.INTER_NEAREST) 144 | labels.append(img) 145 | 146 | inputs = np.stack(inputs, axis=2) 147 | inputs = torch.tensor(inputs).transpose(0, 2).transpose(1, 3) 148 | 149 | labels = torch.tensor(labels) 150 | 151 | yield inputs, labels 152 | 153 | 154 | def decode_segmap_camvid(image): 155 | Sky = [128, 128, 128] 156 | Building = [128, 0, 0] 157 | Pole = [192, 192, 128] 158 
| Road_marking = [255, 69, 0] 159 | Road = [128, 64, 128] 160 | Pavement = [60, 40, 222] 161 | Tree = [128, 128, 0] 162 | SignSymbol = [192, 128, 128] 163 | Fence = [64, 64, 128] 164 | Car = [64, 0, 128] 165 | Pedestrian = [64, 64, 0] 166 | Bicyclist = [0, 128, 192] 167 | 168 | label_colors = np.array([Sky, Building, Pole, Road_marking, Road, 169 | Pavement, Tree, SignSymbol, Fence, Car, 170 | Pedestrian, Bicyclist]).astype(np.uint8) 171 | 172 | r = np.zeros_like(image).astype(np.uint8) 173 | g = np.zeros_like(image).astype(np.uint8) 174 | b = np.zeros_like(image).astype(np.uint8) 175 | 176 | for label in range(len(label_colors)): 177 | r[image == label] = label_colors[label, 0] 178 | g[image == label] = label_colors[label, 1] 179 | b[image == label] = label_colors[label, 2] 180 | 181 | rgb = np.zeros((image.shape[0], image.shape[1], 3)).astype(np.uint8) 182 | rgb[:, :, 0] = r 183 | rgb[:, :, 1] = g 184 | rgb[:, :, 2] = b 185 | 186 | return rgb 187 | 188 | def decode_segmap_cscapes(image, nc=34): 189 | 190 | label_colours = np.array([(0, 0, 0), # 0=background 191 | (0, 0, 0), # 1=ego vehicle 192 | (0, 0, 0), # 2=rectification border 193 | (0, 0, 0), # 3=out of toi 194 | (0, 0, 0), # 4=static 195 | # 5=dynamic, 6=ground, 7=road, 8=sidewalk, 9=parking 196 | (111, 74, 0), ( 81, 0, 81), (128, 64,128), (244, 35,232), (250,170,160), 197 | # 10=rail track, 11=building, 12=wall, 13=fence, 14=guard rail 198 | (230,150,140), ( 70, 70, 70), (102,102,156), (190,153,153), (180,165,180), 199 | # 15=bridge, 16=tunnel, 17=pole, 18=pole group, 19=traffic light 200 | (150,100,100), (150,120, 90), (153,153,153), (153,153,153), (250,170, 30), 201 | # 20=traffic sign, 21=vegetation, 22=terrain, 23=sky, 24=person 202 | (220,220, 0), (107,142, 35), (152,251,152), ( 70,130,180), (220, 20, 60), 203 | # 25=rider, 26=car, 27=truck, 28=bus, 29=caravan, 204 | (255, 0, 0), ( 0, 0,142), ( 0, 0, 70), ( 0, 60,100), ( 0, 0, 90), 205 | # 30=trailer, 31=train, 32=motorcycle, 33=bicycle, 34=license 
plate, 206 | ( 0, 0,110), ( 0, 80,100), ( 0, 0,230), (119, 11, 32), ( 0, 0,142), 207 | ]) 208 | 209 | r = np.zeros_like(image).astype(np.uint8) 210 | g = np.zeros_like(image).astype(np.uint8) 211 | b = np.zeros_like(image).astype(np.uint8) 212 | 213 | for l in range(0, nc): 214 | r[image == l] = label_colours[l, 0] 215 | g[image == l] = label_colours[l, 1] 216 | b[image == l] = label_colours[l, 2] 217 | 218 | rgb = np.zeros((image.shape[0], image.shape[1], 3)).astype(np.uint8) 219 | rgb[:, :, 0] = b 220 | rgb[:, :, 1] = g 221 | rgb[:, :, 2] = r 222 | return rgb 223 | 224 | def show_images(images, in_row=True): 225 | ''' 226 | Helper function to show 3 images 227 | ''' 228 | total_images = len(images) 229 | 230 | rc_tuple = (1, total_images) 231 | if not in_row: 232 | rc_tuple = (total_images, 1) 233 | 234 | #figure = plt.figure(figsize=(20, 10)) 235 | for ii in range(len(images)): 236 | plt.subplot(*rc_tuple, ii+1) 237 | plt.title(images[ii][0]) 238 | plt.axis('off') 239 | plt.imshow(images[ii][1]) 240 | plt.show() 241 | 242 | def get_class_weights(loader, num_classes, c=1.02): 243 | ''' 244 | This class return the class weights for each class 245 | 246 | Arguments: 247 | - loader : The generator object which return all the labels at one iteration 248 | Do Note: That this class expects all the labels to be returned in 249 | one iteration 250 | 251 | - num_classes : The number of classes 252 | 253 | Return: 254 | - class_weights : An array equal in length to the number of classes 255 | containing the class weights for each class 256 | ''' 257 | 258 | _, labels = next(loader) 259 | all_labels = labels.flatten() 260 | each_class = np.bincount(all_labels, minlength=num_classes) 261 | prospensity_score = each_class / len(all_labels) 262 | class_weights = 1 / (np.log(c + prospensity_score)) 263 | return class_weights 264 | --------------------------------------------------------------------------------