├── .gitignore
├── LICENSE
├── README.md
├── checkpoint
│   └── .gitkeep
├── datasets
│   ├── README.md
│   ├── cityscapes
│   │   ├── gtCoarse
│   │   │   └── .gitkeep
│   │   ├── gtFine
│   │   │   └── .gitkeep
│   │   ├── leftImg8bit
│   │   │   └── .gitkeep
│   │   └── results
│   │       └── .gitkeep
│   └── cityscapesscripts
│       └── .gitkeep
├── imagenet-pretrain
│   ├── README.md
│   ├── lednet_imagenet.py
│   └── main.py
├── images
│   ├── LEDNet_demo.png
│   └── LEDNet_overview.png
├── requirements.txt
├── save
│   └── .gitkeep
├── test
│   ├── README.md
│   ├── dataset.py
│   ├── eval_cityscapes_color.py
│   ├── eval_cityscapes_server.py
│   ├── eval_forward_time.py
│   ├── eval_iou.py
│   ├── iouEval.py
│   ├── lednet_no_bn.py
│   └── transform.py
├── train
│   ├── README.md
│   ├── lednet.py
│   ├── lednet_1.py
│   └── main.py
└── utils
    ├── __init__.py
    ├── dataset.py
    ├── iouEval.py
    ├── loss.py
    ├── transform.py
    └── visualize.py
/.gitignore:
--------------------------------------------------------------------------------
1 | #Files
2 | *.pyc
3 | *.pyo
4 | */__pycache__/
5 | */*/__pycache__/
6 | */*/*/__pycache__/
7 | test/save_results/
8 | test/save_color/
9 | .idea/
10 |
11 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Yu Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ### [LEDNet: A Lightweight Encoder-Decoder Network for Real-time Semantic Segmentation](https://github.com/xiaoyufenfei/LEDNet)
2 |
3 | [![python-image]][python-url]
4 | [![pytorch-image]][pytorch-url]
5 |
6 | #### Table of Contents:
7 | - Introduction
8 | - Project Structure
9 | - Installation
10 | - Datasets
11 | - Train
12 | - Resuming training
13 | - Test
14 | - Results
15 | - Reference
16 | - Tips
17 |
18 | #### Introduction
19 |
20 | This project contains the code (Note: the code was tested with python=3.6, cuda=9.0 and PyTorch-0.4.1; it also supports PyTorch-0.4.1+) for: [**LEDNet: A Lightweight Encoder-Decoder Network for Real-time Semantic Segmentation**](https://arxiv.org/pdf/1905.02423.pdf) by [Yu Wang](https://github.com/xiaoyufenfei).
21 |
22 |
23 | The extensive computational burden limits the usage of CNNs on mobile devices for dense estimation tasks such as semantic segmentation. In this paper, we present a lightweight network, namely **LEDNet**, which employs an asymmetric encoder-decoder architecture for real-time semantic segmentation. More specifically, the encoder adopts a ResNet backbone in which two new operations, channel split and shuffle, are used in each residual block to greatly reduce computation cost while maintaining high segmentation accuracy. In the decoder, an attention pyramid network (APN) further lightens the overall network complexity. Our model has fewer than 1M parameters and runs at over 71 FPS on a single GTX 1080Ti GPU. Comprehensive experiments demonstrate that our approach achieves a state-of-the-art speed-accuracy trade-off on the Cityscapes dataset and is an effective method for real-time semantic segmentation tasks.
24 |
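For a concrete picture of the channel split and shuffle operations mentioned above, here is a minimal PyTorch sketch mirroring the `split`/`channel_shuffle` helpers defined in `imagenet-pretrain/lednet_imagenet.py`:

```python
import torch

def split(x):
    # channel split: divide the feature map into two halves along the channel dimension
    c = x.size(1)
    return x[:, :c // 2].contiguous(), x[:, c // 2:].contiguous()

def channel_shuffle(x, groups):
    # channel shuffle: interleave channels across groups so the two branches exchange information
    n, c, h, w = x.size()
    x = x.view(n, groups, c // groups, h, w).transpose(1, 2).contiguous()
    return x.view(n, -1, h, w)

# toy usage: split a 64-channel map, concatenate the two halves again, then shuffle
feat = torch.randn(1, 64, 32, 32)
x1, x2 = split(feat)
out = channel_shuffle(torch.cat([x1, x2], dim=1), groups=2)
```
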
25 | #### Project-Structure
26 | ```
27 | ├── datasets # contains all datasets for the project
28 | | └── cityscapes # cityscapes dataset
29 | | | └── gtCoarse # Coarse cityscapes annotation
30 | | | └── gtFine # Fine cityscapes annotation
31 | | | └── leftImg8bit # cityscapes training image
32 | | └── cityscapesscripts # cityscapes dataset label convert scripts!
33 | ├── utils
34 | | └── dataset.py # dataloader for cityscapes dataset
35 | | └── iouEval.py # for test 'iou mean' and 'iou per class'
36 | | └── transform.py # data preprocessing
37 | | └── visualize.py # Visualize with visdom
38 | | └── loss.py # loss function
39 | ├── checkpoint
40 | | └── xxx.pth # pretrained models encoder form ImageNet
41 | ├── save
42 | | └── xxx.pth # trained models form scratch
43 | ├── imagenet-pretrain
44 | | └── lednet_imagenet.py #
45 | | └── main.py #
46 | ├── train
47 | | └── lednet.py # model definition for semantic segmentation
48 | | └── main.py # train model scripts
49 | ├── test
50 | | └── dataset.py
51 | | └── lednet.py  # model definition
52 | | └── lednet_no_bn.py  # model definition with the BN layers removed
53 | | └── eval_cityscapes_color.py  # generate RGB segmentation images for visualization
54 | | └── eval_cityscapes_server.py  # generate results for upload to the official server
55 | | └── eval_forward_time.py  # measure model inference time
56 | | └── eval_iou.py
57 | | └── iouEval.py
58 | | └── transform.py
59 | ```
60 |
61 | #### Installation
62 | - Python 3.6.x. Recommended using [Anaconda3](https://www.anaconda.com/distribution/)
63 | - Set up python environment
64 |
65 | ```
66 | pip3 install -r requirements.txt
67 | ```
68 |
69 | - Env: PyTorch 0.4.1; CUDA 9.0; cuDNN 7.1; Python 3.6
70 |
71 | - Clone this repository.
72 |
73 | ```
74 | git clone https://github.com/xiaoyufenfei/LEDNet.git
75 | cd LEDNet-master
76 | ```
77 |
78 | - Install [Visdom](https://github.com/facebookresearch/visdom).
79 | - Install [torchsummary](https://github.com/sksq96/pytorch-summary)
80 | - Download the dataset by following the **Datasets** below.
81 | - Note: For training, we currently support [cityscapes](https://www.cityscapes-dataset.com/); we aim to add the [Camvid](https://github.com/alexgkendall/SegNet-Tutorial/tree/master/CamVid), [VOC](http://host.robots.ox.ac.uk/pascal/VOC/) and [ADE20K](http://groups.csail.mit.edu/vision/datasets/ADE20K/) datasets.
82 |
83 | #### Datasets
84 | - You can download [cityscapes](https://www.cityscapes-dataset.com/) from [here](https://www.cityscapes-dataset.com/downloads/). Note: please download [leftImg8bit_trainvaltest.zip(11GB)](https://www.cityscapes-dataset.com/file-handling/?packageID=4) and [gtFine_trainvaltest(241MB)](https://www.cityscapes-dataset.com/file-handling/?packageID=1) and [gtCoarse(1.3GB)](https://www.cityscapes-dataset.com/file-handling/?packageID=1).
85 | - You can download [CityscapesScripts](https://github.com/mcordts/cityscapesScripts) and convert the dataset to [19 categories](https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/helpers/labels.py) (see the sketch after the structure below). The dataset should have this basic structure:
86 |
87 | ```
88 | ├── leftImg8bit
89 | │ ├── train
90 | │ ├── val
91 | │ └── test
92 | ├── gtFine
93 | │ ├── train
94 | │ ├── val
95 | │ └── test
96 | ├── gtCoarse
97 | │ ├── train
98 | │ ├── train_extra
99 | │ └── val
100 | ```
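
To generate the 19-category `*_labelTrainIds.png` files that the dataloaders in this repo expect (see `is_label` in test/dataset.py), one option is the `createTrainIdLabelImgs.py` script shipped with cityscapesScripts. A minimal sketch, where the paths are assumptions to adapt to your layout:

```python
import os
import subprocess

# point cityscapesScripts at the dataset root (assumed location)
os.environ["CITYSCAPES_DATASET"] = os.path.abspath("./datasets/cityscapes")

# run the conversion script from cityscapesScripts (assumed checkout location)
subprocess.run(
    ["python", "./datasets/cityscapesscripts/preparation/createTrainIdLabelImgs.py"],
    check=True,
)
```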
101 |
102 | #### Training-LEDNet
103 |
104 | - For help on the optional arguments you can run: `python main.py -h`
105 |
106 | - By default, we assume you have downloaded the cityscapes dataset into the `./datasets/cityscapes` dir.
107 | - To train LEDNet, run the train/main.py script and pass the parameters listed in `main.py` as flags, or change them manually.
108 |
109 | ```
110 | python main.py --savedir logs --model lednet --datadir path/root_directory/ --num-epochs xx --batch-size xx ...
111 | ```
112 |
113 | #### Resuming-training-if-decoder-part-interrupted
114 |
115 | - For help on the optional arguments you can run: `python main.py -h`
116 |
117 | ```
118 | python main.py --savedir logs --name lednet --datadir path/root_directory/ --num-epochs xx --batch-size xx --decoder --state "../save/logs/model_best_enc.pth.tar"...
119 | ```
120 |
121 | #### Testing
122 |
123 | - The models produced during training can be found [here](https://github.com/xiaoyufenfei/LEDNet/save/). They may not be the best ones; you can train a model from scratch yourself, or fine-tune the decoder with an encoder pre-trained on ImageNet. For instance:
124 |
125 | ```
126 | For more details, refer to ./test/README.md
127 | ```
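
As a rough illustration of fine-tuning with a pretrained encoder, the sketch below copies only the matching parameters from an ImageNet checkpoint into the segmentation network, in the same spirit as the `load_my_state_dict` helpers in the test scripts. The checkpoint path and key prefixes are assumptions; adapt them to how your checkpoint was saved:

```python
import torch
from lednet import Net  # segmentation model definition (train/lednet.py)

def load_matching_state_dict(model, state_dict):
    # copy only parameters whose names and shapes match the target model
    own_state = model.state_dict()
    for name, param in state_dict.items():
        name = name.replace("module.", "")  # strip a possible DataParallel prefix
        if name in own_state and own_state[name].shape == param.shape:
            own_state[name].copy_(param)
    return model

model = Net(20)  # 19 Cityscapes classes + 1 ignore class
# hypothetical path: imagenet-pretrain/main.py saves 'model_best.pth.tar'
ckpt = torch.load("../checkpoint/model_best.pth.tar", map_location="cpu")
model = load_matching_state_dict(model, ckpt["state_dict"])
```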
128 |
129 | #### Results
130 |
131 | - Please refer to our article for more details.
132 |
133 | |Method|Dataset|Fine|Coarse| IoU_cla |IoU_cat|FPS|
134 | |:-:|:-:|:-:|:-:|:-:|:-:|:-:|
135 | |**LEDNet**|**cityscapes**|yes|yes|**70.6%**|**87.1%**|**70+**|
136 |
137 | qualitative segmentation result examples:
138 |
139 | ![LEDNet demo](./images/LEDNet_demo.png)
140 |
141 | #### Citation
142 |
143 | If you find this code useful for your research, please use the following BibTeX entry.
144 |
145 | ```
146 | @article{wang2019lednet,
147 | title={LEDNet: A Lightweight Encoder-Decoder Network for Real-time Semantic Segmentation},
148 | author={Wang, Yu and Zhou, Quan and Liu, Jia and Xiong, Jian and Gao, Guangwei and Wu, Xiaofu and Latecki, Longin Jan},
149 | journal={arXiv preprint arXiv:1905.02423},
150 | year={2019}
151 | }
152 | ```
153 |
154 | #### Tips
155 |
156 | - Due to limited GPU resources, the project results still need further improvement...
157 | - It is recommended to pre-train the encoder on ImageNet and then fine-tune the decoder part; this gives better results.
158 |
159 | #### Reference
160 |
161 | 1. [**Deep residual learning for image recognition**](https://arxiv.org/pdf/1512.03385.pdf)
162 | 2. [**Enet: A deep neural network architecture for real-time semantic segmentation**](https://arxiv.org/pdf/1606.02147.pdf)
163 | 3. [**Erfnet: Efficient residual factorized convnet for real-time semantic segmentation**](https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=8063438)
164 | 4. [**Shufflenet: An extremely efficient convolutional neural network for mobile devices**](https://arxiv.org/pdf/1707.01083.pdf)
165 |
166 |
170 |
171 | [python-image]: https://img.shields.io/badge/Python-3.x-ff69b4.svg
172 | [python-url]: https://www.python.org/
173 | [pytorch-image]: https://img.shields.io/badge/PyTorch-1.0-2BAF2B.svg
174 | [pytorch-url]: https://pytorch.org/
--------------------------------------------------------------------------------
/checkpoint/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaoyufenfei/LEDNet/5d900d9cfabb3091c952b79be34246aea0608e42/checkpoint/.gitkeep
--------------------------------------------------------------------------------
/datasets/README.md:
--------------------------------------------------------------------------------
1 | #### Cityscapes dataset overview:
2 |
3 | 1. The data must be **registered for and downloaded** from the official website, https://www.cityscapes-dataset.com/. The website also shows the benchmark results achieved by everyone's networks.
4 |
5 | 2. The data preprocessing and evaluation code can be **downloaded** from: https://github.com/mcordts/cityscapesScripts
6 |
7 | The original images are stored in the **leftImg8bit** folder, and the finely annotated data in the **gtFine** (gt: ground truth) folder. The training set (train) contains **2975** images and the validation set (val) **500** images, all with corresponding labels. The test set (test) only provides the original images without labels; the official server uses it to evaluate submitted results (to prevent anyone from training on the test set to inflate the benchmark). In practice you can therefore use the validation set for testing. Coarsely annotated data is stored in the **gtCoarse** (gt: ground truth) folder.
8 |
9 | Each image corresponds to 4 annotation files: _gtFine_polygons.json stores the classes and their regions (a region boundary is given by the positions of the polygon vertices); _gtFine_labelIds.png contains values 0-33, where different values represent different classes and the value-to-class mapping is defined in cityscapesscripts/helpers/labels.py; _gtFine_instanceIds.png holds the instance segmentation; _gtFine_color.png is for visualization, and the color-to-class mapping is also described in labels.py.
10 |
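For reference, a minimal sketch (assuming cityscapesScripts is importable) of how the pixel values in `*_labelIds.png` map to classes and train IDs via `helpers/labels.py`:

```python
import numpy as np
from cityscapesscripts.helpers.labels import labels, id2label

# example: what does pixel value 26 in a *_labelIds.png mean?
print(id2label[26].name, id2label[26].trainId, id2label[26].color)  # car 13 (0, 0, 142)

# build a raw-id -> train-id lookup table (255 = ignore) for converting label images
lut = np.full(256, 255, dtype=np.uint8)
for lbl in labels:
    if 0 <= lbl.id < 256 and 0 <= lbl.trainId < 255:
        lut[lbl.id] = lbl.trainId
```
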
11 | #### Dataset Structure:
12 |
13 | ```
14 | ├── datasets # contains all datasets for the project
15 | | └── cityscapes # cityscapes dataset
16 | | | └── gtCoarse # Coarse cityscapes annotation
17 | | | └── gtFine # Fine cityscapes annotation
18 | | | └── leftImg8bit # cityscapes training image
19 | | | └── results #results move here for eval by evalPixelLevelSemanticLabeling.py
20 | | └── cityscapesscripts # cityscapes dataset label convert scripts!
21 | | | └── annotation #
22 | | | └── evaluation #
23 | | | └── helpers #
24 | | | └── preparation #
25 | | | └── viewer #
26 | | | └── __init__.py #
27 |
28 | ```
29 |
30 |
--------------------------------------------------------------------------------
/datasets/cityscapes/gtCoarse/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaoyufenfei/LEDNet/5d900d9cfabb3091c952b79be34246aea0608e42/datasets/cityscapes/gtCoarse/.gitkeep
--------------------------------------------------------------------------------
/datasets/cityscapes/gtFine/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaoyufenfei/LEDNet/5d900d9cfabb3091c952b79be34246aea0608e42/datasets/cityscapes/gtFine/.gitkeep
--------------------------------------------------------------------------------
/datasets/cityscapes/leftImg8bit/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaoyufenfei/LEDNet/5d900d9cfabb3091c952b79be34246aea0608e42/datasets/cityscapes/leftImg8bit/.gitkeep
--------------------------------------------------------------------------------
/datasets/cityscapes/results/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaoyufenfei/LEDNet/5d900d9cfabb3091c952b79be34246aea0608e42/datasets/cityscapes/results/.gitkeep
--------------------------------------------------------------------------------
/datasets/cityscapesscripts/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaoyufenfei/LEDNet/5d900d9cfabb3091c952b79be34246aea0608e42/datasets/cityscapesscripts/.gitkeep
--------------------------------------------------------------------------------
/imagenet-pretrain/README.md:
--------------------------------------------------------------------------------
1 | #### Model and ImageNet pretraining script
2 |
3 | This folder contains the script and model definition to pretrain LEDNet's encoder on the ImageNet dataset.
4 |
5 | The script is adapted from the code in the [PyTorch ImageNet example](https://github.com/pytorch/examples/tree/master/imagenet). Please make sure that you have the ImageNet dataset split into train and val folders before launching the script. Refer to that repository for instructions on usage and main.py options. Basic command:
6 |
7 | ```
8 | python main.py
9 | ```
10 |
11 |
12 | #### Third Party Project Reference
13 |
14 | - [ImageNet training in PyTorch](https://github.com/pytorch/examples/tree/master/imagenet)
15 |
16 | - [ShuffleNetv2 in PyTorch](https://github.com/Randl/ShuffleNetV2-pytorch)
17 |
18 |
--------------------------------------------------------------------------------
/imagenet-pretrain/lednet_imagenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.init as init
4 | import torch.nn.functional as F
5 |
6 | def split(x):
7 | c = int(x.size()[1])
8 | c1 = round(c * 0.5)
9 | x1 = x[:, :c1, :, :].contiguous()
10 | x2 = x[:, c1:, :, :].contiguous()
11 |
12 | return x1, x2
13 |
14 | def channel_shuffle(x,groups):
15 | batchsize, num_channels, height, width = x.data.size()
16 |
17 | channels_per_group = num_channels // groups
18 |
19 | # reshape
20 | x = x.view(batchsize,groups,
21 | channels_per_group,height,width)
22 |
23 | x = torch.transpose(x,1,2).contiguous()
24 |
25 | # flatten
26 | x = x.view(batchsize,-1,height,width)
27 |
28 | return x
29 |
30 |
31 | class Conv2dBnRelu(nn.Module):
32 | def __init__(self,in_ch,out_ch,kernel_size=3,stride=1,padding=0,dilation=1,bias=True):
33 | super(Conv2dBnRelu,self).__init__()
34 |
35 | self.conv = nn.Sequential(
36 | nn.Conv2d(in_ch,out_ch,kernel_size,stride,padding,dilation=dilation,bias=bias),
37 | nn.BatchNorm2d(out_ch, eps=1e-3),
38 | nn.ReLU(inplace=True)
39 | )
40 |
41 | def forward(self, x):
42 | return self.conv(x)
43 |
44 |
45 | # After the Concat -> BN, you can also add Dropout (as in SS_nbt_module); it may give better results.
46 | class DownsamplerBlock (nn.Module):
47 | def __init__(self, in_channel, out_channel):
48 | super(DownsamplerBlock,self).__init__()
49 |
50 | self.conv = nn.Conv2d(in_channel, out_channel-in_channel, (3, 3), stride=2, padding=1, bias=True)
51 | self.pool = nn.MaxPool2d(2, stride=2)
52 | self.bn = nn.BatchNorm2d(out_channel, eps=1e-3)
53 | self.relu = nn.ReLU(inplace=True)
54 |
55 | def forward(self, input):
56 | x1 = self.pool(input)
57 | x2 = self.conv(input)
58 |
59 | diffY = x2.size()[2] - x1.size()[2]
60 | diffX = x2.size()[3] - x1.size()[3]
61 |
62 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
63 | diffY // 2, diffY - diffY // 2])
64 |
65 | output = torch.cat([x2, x1], 1)
66 | output = self.bn(output)
67 | output = self.relu(output)
68 | return output
69 |
70 |
71 | class SS_nbt_module(nn.Module):
72 | def __init__(self, chann, dropprob, dilated):
73 | super().__init__()
74 |
75 | oup_inc = chann//2
76 |
77 | # dw
78 | self.conv3x1_1_l = nn.Conv2d(oup_inc, oup_inc, (3,1), stride=1, padding=(1,0), bias=True)
79 |
80 | self.conv1x3_1_l = nn.Conv2d(oup_inc, oup_inc, (1,3), stride=1, padding=(0,1), bias=True)
81 |
82 | self.bn1_l = nn.BatchNorm2d(oup_inc, eps=1e-03)
83 |
84 | self.conv3x1_2_l = nn.Conv2d(oup_inc, oup_inc, (3,1), stride=1, padding=(1*dilated,0), bias=True, dilation = (dilated,1))
85 |
86 | self.conv1x3_2_l = nn.Conv2d(oup_inc, oup_inc, (1,3), stride=1, padding=(0,1*dilated), bias=True, dilation = (1,dilated))
87 |
88 | self.bn2_l = nn.BatchNorm2d(oup_inc, eps=1e-03)
89 |
90 | # dw
91 | self.conv3x1_1_r = nn.Conv2d(oup_inc, oup_inc, (3,1), stride=1, padding=(1,0), bias=True)
92 |
93 | self.conv1x3_1_r = nn.Conv2d(oup_inc, oup_inc, (1,3), stride=1, padding=(0,1), bias=True)
94 |
95 | self.bn1_r = nn.BatchNorm2d(oup_inc, eps=1e-03)
96 |
97 | self.conv3x1_2_r = nn.Conv2d(oup_inc, oup_inc, (3,1), stride=1, padding=(1*dilated,0), bias=True, dilation = (dilated,1))
98 |
99 | self.conv1x3_2_r = nn.Conv2d(oup_inc, oup_inc, (1,3), stride=1, padding=(0,1*dilated), bias=True, dilation = (1,dilated))
100 |
101 | self.bn2_r = nn.BatchNorm2d(oup_inc, eps=1e-03)
102 |
103 | self.relu = nn.ReLU(inplace=True)
104 | self.dropout = nn.Dropout2d(dropprob)
105 |
106 | @staticmethod
107 | def _concat(x,out):
108 | return torch.cat((x,out),1)
109 |
110 | def forward(self, input):
111 |
112 | # x1 = input[:,:(input.shape[1]//2),:,:]
113 | # x2 = input[:,(input.shape[1]//2):,:,:]
114 | residual = input
115 | x1, x2 = split(input)
116 |
117 | output1 = self.conv3x1_1_l(x1)
118 | output1 = self.relu(output1)
119 | output1 = self.conv1x3_1_l(output1)
120 | output1 = self.bn1_l(output1)
121 | output1 = self.relu(output1)
122 |
123 | output1 = self.conv3x1_2_l(output1)
124 | output1 = self.relu(output1)
125 | output1 = self.conv1x3_2_l(output1)
126 | output1 = self.bn2_l(output1)
127 |
128 |
129 | output2 = self.conv1x3_1_r(x2)
130 | output2 = self.relu(output2)
131 | output2 = self.conv3x1_1_r(output2)
132 | output2 = self.bn1_r(output2)
133 | output2 = self.relu(output2)
134 |
135 | output2 = self.conv1x3_2_r(output2)
136 | output2 = self.relu(output2)
137 | output2 = self.conv3x1_2_r(output2)
138 | output2 = self.bn2_r(output2)
139 |
140 | if (self.dropout.p != 0):
141 | output1 = self.dropout(output1)
142 | output2 = self.dropout(output2)
143 |
144 | out = self._concat(output1,output2)
145 | out = F.relu(residual + out, inplace=True)
146 | return channel_shuffle(out, 2)
147 |
148 |
149 | class Encoder(nn.Module):
150 | def __init__(self):
151 | super().__init__()
152 | self.initial_block = DownsamplerBlock(3,32)
153 |
154 | self.layers = nn.ModuleList()
155 |
156 | for x in range(0, 3):
157 | self.layers.append(SS_nbt_module(32, 0.03, 1))
158 |
159 |
160 | self.layers.append(DownsamplerBlock(32,64))
161 |
162 |
163 | for x in range(0, 2):
164 | self.layers.append(SS_nbt_module(64, 0.03, 1))
165 |
166 | self.layers.append(DownsamplerBlock(64,128))
167 |
168 | for x in range(0, 1):
169 | self.layers.append(SS_nbt_module(128, 0.3, 1))
170 | self.layers.append(SS_nbt_module(128, 0.3, 2))
171 | self.layers.append(SS_nbt_module(128, 0.3, 5))
172 | self.layers.append(SS_nbt_module(128, 0.3, 9))
173 |
174 | for x in range(0, 1):
175 | self.layers.append(SS_nbt_module(128, 0.3, 2))
176 | self.layers.append(SS_nbt_module(128, 0.3, 5))
177 | self.layers.append(SS_nbt_module(128, 0.3, 9))
178 | self.layers.append(SS_nbt_module(128, 0.3, 17))
179 |
180 |
181 | def forward(self, input):
182 | output = self.initial_block(input)
183 |
184 | for layer in self.layers:
185 | output = layer(output)
186 |
187 | return output
188 |
189 |
190 | class Features(nn.Module):
191 | def __init__(self):
192 | super().__init__()
193 | self.encoder = Encoder()
194 | self.extralayer1 = nn.MaxPool2d(2, stride=2)
195 | self.extralayer2 = nn.AvgPool2d(14,1,0)
196 |
197 | def forward(self, input):
198 | output = self.encoder(input)
199 | output = self.extralayer1(output)
200 | output = self.extralayer2(output)
201 | return output
202 |
203 | class Classifier(nn.Module):
204 | def __init__(self, num_classes):
205 | super().__init__()
206 | self.linear = nn.Linear(128, num_classes)
207 |
208 | def forward(self, input):
209 | output = input.view(input.size(0), 128) #first is batch_size
210 | output = self.linear(output)
211 | return output
212 |
213 | class LEDNet(nn.Module):
214 | def __init__(self, num_classes): #use encoder to pass pretrained encoder
215 | super().__init__()
216 |
217 | self.features = Features()
218 | self.classifier = Classifier(num_classes)
219 |
220 | def forward(self, input):
221 | output = self.features(input)
222 | output = self.classifier(output)
223 | return output
224 |
225 |
--------------------------------------------------------------------------------
/imagenet-pretrain/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import shutil
4 | import time
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.parallel
9 | import torch.backends.cudnn as cudnn
10 | import torch.optim
11 | import torch.utils.data
12 | import torchvision.transforms as transforms
13 | import torchvision.datasets as datasets
14 | import torchvision.models as models
15 |
16 | from torch.optim import lr_scheduler
17 |
18 | from lednet_imagenet import LEDNet
19 |
20 | model_names = sorted(name for name in models.__dict__
21 | if name.islower() and not name.startswith("__")
22 | and callable(models.__dict__[name]))
23 |
24 |
25 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
26 | parser.add_argument('data', metavar='DIR',
27 | help='path to dataset')
28 | parser.add_argument('--arch', '-a', metavar='ARCH', default='lednet',
29 | #choices=model_names,
30 | help='model architecture: ' +
31 | ' | '.join(model_names) +
32 | ' (default: resnet18)')
33 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
34 | help='number of data loading workers (default: 4)')
35 | parser.add_argument('--epochs', default=150, type=int, metavar='N',
36 | help='number of total epochs to run')
37 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
38 | help='manual epoch number (useful on restarts)')
39 | parser.add_argument('-b', '--batch-size', default=256, type=int,
40 | metavar='N', help='mini-batch size (default: 256)')
41 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
42 | metavar='LR', help='initial learning rate')
43 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
44 | help='momentum')
45 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
46 | metavar='W', help='weight decay (default: 1e-4)')
47 | parser.add_argument('--print-freq', '-p', default=10, type=int,
48 | metavar='N', help='print frequency (default: 10)')
49 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
50 | help='path to latest checkpoint (default: none)')
51 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
52 | help='evaluate model on validation set')
53 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
54 | help='use pre-trained model')
55 |
56 | best_prec1 = 0
57 |
58 | def main():
59 | global args, best_prec1
60 | args = parser.parse_args()
61 |
62 | # create model
63 | if (args.arch == 'lednet'):
64 | model = LEDNet(1000)
65 | else:
66 | if args.pretrained:
67 | print("=> using pre-trained model '{}'".format(args.arch))
68 | model = models.__dict__[args.arch](pretrained=True)
69 | else:
70 | print("=> creating model '{}'".format(args.arch))
71 | model = models.__dict__[args.arch]()
72 |
73 | model = torch.nn.DataParallel(model).cuda()
74 |
75 | # define loss function (criterion) and optimizer
76 | criterion = nn.CrossEntropyLoss().cuda()
77 |
78 | optimizer = torch.optim.SGD(model.parameters(), args.lr,
79 | momentum=args.momentum,
80 | weight_decay=args.weight_decay)
81 |
82 | # optionally resume from a checkpoint
83 | if args.resume:
84 | if os.path.isfile(args.resume):
85 | print("=> loading checkpoint '{}'".format(args.resume))
86 | checkpoint = torch.load(args.resume)
87 | args.start_epoch = checkpoint['epoch']
88 | best_prec1 = checkpoint['best_prec1']
89 | model.load_state_dict(checkpoint['state_dict'])
90 | optimizer.load_state_dict(checkpoint['optimizer'])
91 | print("=> loaded checkpoint '{}' (epoch {})"
92 | .format(args.resume, checkpoint['epoch']))
93 | else:
94 | print("=> no checkpoint found at '{}'".format(args.resume))
95 |
96 | cudnn.benchmark = True
97 |
98 | # Data loading code
99 | traindir = os.path.join(args.data, 'train')
100 | valdir = os.path.join(args.data, 'val')
101 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
102 | std=[0.229, 0.224, 0.225])
103 |
104 | train_loader = torch.utils.data.DataLoader(
105 | datasets.ImageFolder(traindir, transforms.Compose([
106 | #RemoveExif(),
107 | transforms.RandomResizedCrop(224),
108 | transforms.RandomHorizontalFlip(),
109 | transforms.ToTensor(),
110 | normalize,
111 | ])),
112 | batch_size=args.batch_size, shuffle=True,
113 | num_workers=args.workers, pin_memory=True)
114 |
115 | val_loader = torch.utils.data.DataLoader(
116 | datasets.ImageFolder(valdir, transforms.Compose([
117 | transforms.Resize(256),
118 | transforms.CenterCrop(224),
119 | transforms.ToTensor(),
120 | normalize,
121 | ])),
122 | batch_size=args.batch_size, shuffle=False,
123 | num_workers=args.workers, pin_memory=True)
124 |
125 | if args.evaluate:
126 | validate(val_loader, model, criterion)
127 | return
128 |
129 | for epoch in range(args.start_epoch, args.epochs):
130 | adjust_learning_rate(optimizer, epoch)
131 |
132 | # train for one epoch
133 | train(train_loader, model, criterion, optimizer, epoch)
134 |
135 | # evaluate on validation set
136 | prec1 = validate(val_loader, model, criterion)
137 |
138 | # remember best prec@1 and save checkpoint
139 | is_best = prec1 > best_prec1
140 | best_prec1 = max(prec1, best_prec1)
141 | save_checkpoint({
142 | 'epoch': epoch + 1,
143 | 'arch': args.arch,
144 | 'state_dict': model.state_dict(),
145 | 'best_prec1': best_prec1,
146 | 'optimizer' : optimizer.state_dict(),
147 | }, is_best)
148 |
149 | #scheduler.step(prec1, epoch) #decreases learning rate if prec1 plateaus
150 |
151 |
152 | def train(train_loader, model, criterion, optimizer, epoch):
153 | batch_time = AverageMeter()
154 | data_time = AverageMeter()
155 | losses = AverageMeter()
156 | top1 = AverageMeter()
157 | top5 = AverageMeter()
158 |
159 | # switch to train mode
160 | model.train()
161 |
162 | end = time.time()
163 | for i, (input, target) in enumerate(train_loader):
164 | # measure data loading time
165 | data_time.update(time.time() - end)
166 |
167 | target = target.cuda(non_blocking=True)
168 | input_var = torch.autograd.Variable(input)
169 | target_var = torch.autograd.Variable(target)
170 |
171 | # compute output
172 | output = model(input_var)
173 | loss = criterion(output, target_var)
174 |
175 | # measure accuracy and record loss
176 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
177 | losses.update(loss.data[0], input.size(0))
178 | top1.update(prec1[0], input.size(0))
179 | top5.update(prec5[0], input.size(0))
180 |
181 | # compute gradient and do SGD step
182 | optimizer.zero_grad()
183 | loss.backward()
184 | optimizer.step()
185 |
186 | # measure elapsed time
187 | batch_time.update(time.time() - end)
188 | end = time.time()
189 |
190 | if i % args.print_freq == 0:
191 | for param_group in optimizer.param_groups:
192 | lr = param_group['lr']
193 | print('Epoch: [{0}][{1}/{2}][lr:{lr:.6g}]\t'
194 | 'Time {batch_time.val:.3f} ({batch_time.avg:.2f}) / '
195 | 'Data {data_time.val:.3f} ({data_time.avg:.2f})\t'
196 | 'Loss {loss.val:.3f} ({loss.avg:.3f})\t'
197 | 'Prec@1 {top1.val:.2f} ({top1.avg:.2f})\t'
198 | 'Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(
199 | epoch, i, len(train_loader), batch_time=batch_time,
200 | data_time=data_time, loss=losses, top1=top1, top5=top5, lr=lr))
201 |
202 |
203 | def validate(val_loader, model, criterion):
204 | batch_time = AverageMeter()
205 | losses = AverageMeter()
206 | top1 = AverageMeter()
207 | top5 = AverageMeter()
208 |
209 | # switch to evaluate mode
210 | model.eval()
211 |
212 | end = time.time()
213 | for i, (input, target) in enumerate(val_loader):
214 | target = target.cuda(non_blocking=True)
215 | input_var = torch.autograd.Variable(input, volatile=True)
216 | target_var = torch.autograd.Variable(target, volatile=True)
217 |
218 | # compute output
219 | output = model(input_var)
220 | loss = criterion(output, target_var)
221 |
222 | # measure accuracy and record loss
223 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
224 | losses.update(loss.data[0], input.size(0))
225 | top1.update(prec1[0], input.size(0))
226 | top5.update(prec5[0], input.size(0))
227 |
228 | # measure elapsed time
229 | batch_time.update(time.time() - end)
230 | end = time.time()
231 |
232 | if i % args.print_freq == 0:
233 | print('Test: [{0}/{1}]\t'
234 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
235 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
236 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
237 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
238 | i, len(val_loader), batch_time=batch_time, loss=losses,
239 | top1=top1, top5=top5))
240 |
241 | print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
242 | .format(top1=top1, top5=top5))
243 |
244 | return top1.avg
245 |
246 |
247 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
248 | torch.save(state, filename)
249 | if is_best:
250 | shutil.copyfile(filename, 'model_best.pth.tar')
251 |
252 |
253 | class AverageMeter(object):
254 | """Computes and stores the average and current value"""
255 | def __init__(self):
256 | self.reset()
257 |
258 | def reset(self):
259 | self.val = 0
260 | self.avg = 0
261 | self.sum = 0
262 | self.count = 0
263 |
264 | def update(self, val, n=1):
265 | self.val = val
266 | self.sum += val * n
267 | self.count += n
268 | self.avg = self.sum / self.count
269 |
270 |
271 | def adjust_learning_rate(optimizer, epoch):
272 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
273 | lr = args.lr
274 | wd = 1e-4
275 | milestone = 15 #after epoch milestone, lr is reduced exponentially
276 | if epoch > milestone:
277 | lr = args.lr * (0.95 ** (epoch-milestone))
278 | wd = 0
279 | for param_group in optimizer.param_groups:
280 | param_group['lr'] = lr
281 | param_group['weight_decay'] = wd
282 |
283 | def accuracy(output, target, topk=(1,)):
284 | """Computes the precision@k for the specified values of k"""
285 | maxk = max(topk)
286 | batch_size = target.size(0)
287 |
288 | _, pred = output.topk(maxk, 1, True, True)
289 | pred = pred.t()
290 | correct = pred.eq(target.view(1, -1).expand_as(pred))
291 |
292 | res = []
293 | for k in topk:
294 | correct_k = correct[:k].view(-1).float().sum(0)
295 | res.append(correct_k.mul_(100.0 / batch_size))
296 | return res
297 |
298 |
299 | if __name__ == '__main__':
300 | main()
301 |
--------------------------------------------------------------------------------
/images/LEDNet_demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaoyufenfei/LEDNet/5d900d9cfabb3091c952b79be34246aea0608e42/images/LEDNet_demo.png
--------------------------------------------------------------------------------
/images/LEDNet_overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaoyufenfei/LEDNet/5d900d9cfabb3091c952b79be34246aea0608e42/images/LEDNet_overview.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==0.4.1
2 | torchvision==0.2.1
3 | torchsummary==1.5.1
4 | visdom==0.1.8.4
5 |
--------------------------------------------------------------------------------
/save/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaoyufenfei/LEDNet/5d900d9cfabb3091c952b79be34246aea0608e42/save/.gitkeep
--------------------------------------------------------------------------------
/test/README.md:
--------------------------------------------------------------------------------
1 | #### Functions for evaluating/visualizing the network's output
2 |
3 | Currently there are 4 usable functions to evaluate stuff:
4 | - eval_cityscapes_color
5 | - eval_cityscapes_server
6 | - eval_iou
7 | - eval_forward_time
8 |
9 | #### eval_cityscapes_server.py
10 |
11 | This code can be used to produce segmentations of the Cityscapes images and convert the output indices to the original 'labelIds' so they can be evaluated with the scripts from the Cityscapes dataset (evalPixelLevelSemanticLabeling.py) or uploaded to the Cityscapes test server. By default it saves images in the test/save_results/ folder.
12 |
13 | **Options:** Specify the Cityscapes folder path with '--datadir' option. Select the cityscapes subset with '--subset' ('val', 'test', 'train' or 'demoSequence'). For other options check the bottom side of the file.
14 |
15 | **Examples:**
16 | ```
17 | python eval_cityscapes_server.py --datadir /xx/datasets/cityscapes/ --loadDir ../save/logs/ --loadWeights model_best.pth --loadModel lednet.py --subset val
18 | ```
19 |
20 | #### eval_cityscapes_color.py
21 |
22 | This code can be used to produce color segmentations of the Cityscapes images for visualization purposes. By default it saves images in the test/save_color/ folder. You can also visualize the results in visdom with the --visualize flag.
23 |
24 | **Options:** Specify the Cityscapes folder path with '--datadir' option. Select the cityscapes subset with '--subset' ('val', 'test', 'train' or 'demoSequence'). For other options check the bottom side of the file.
25 |
26 | **Examples:**
27 |
28 | ```
29 | python eval_cityscapes_color.py --datadir /xx/datasets/cityscapes/ --loadDir ../save/logs/ --loadWeights model_best.pth --loadModel lednet.py --subset val
30 | ```
31 |
32 | #### eval_iou.py
33 |
34 | This code can be used to calculate the IoU (mean and per-class) in a subset of images with labels available, like Cityscapes val/train sets.
35 |
36 | **Options:** Specify the Cityscapes folder path with '--datadir' option. Select the cityscapes subset with '--subset' ('val' or 'train'). For other options check the bottom side of the file.
37 |
38 | **Examples:**
39 |
40 | ```
41 | python eval_iou.py --datadir /xx/datasets/cityscapes/ --loadDir ../save/logs/ --loadWeights model_best.pth --loadModel lednet.py --subset val
42 | ```
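
For reference, `iouEval.py` accumulates per-class true positives, false positives and false negatives over all batches and computes IoU as TP / (TP + FP + FN); a minimal sketch of that final step:

```python
import torch

def iou_from_counts(tp, fp, fn):
    # per-class IoU = TP / (TP + FP + FN); mean IoU averages over the classes
    iou = tp / (tp + fp + fn + 1e-15)
    return iou.mean(), iou

# toy counts for two classes
mean_iou, per_class = iou_from_counts(
    torch.tensor([50.0, 30.0]), torch.tensor([5.0, 10.0]), torch.tensor([10.0, 5.0]))
```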
43 |
44 | #### eval_forward_time.py
45 | This script loads the LEDNet model (from lednet_no_bn.py) and enters a loop to continuously estimate the forward pass time (fwt) at the specified resolution.
46 |
47 | **Options:** Option '--width' specifies the width (default: 1024). Option '--height' specifies the height (default: 512). For other options check the bottom side of the file.
48 |
49 | **Examples:**
50 | ```
51 | python eval_forward_time.py --batch-size=6
52 | ```
53 |
54 |
55 |
56 |
--------------------------------------------------------------------------------
/test/dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 |
4 | from PIL import Image
5 |
6 | from torch.utils.data import Dataset
7 |
8 | EXTENSIONS = ['.jpg', '.png']
9 |
10 | def load_image(file):
11 | return Image.open(file)
12 |
13 | def is_image(filename):
14 | return any(filename.endswith(ext) for ext in EXTENSIONS)
15 |
16 | def is_label(filename):
17 | return filename.endswith("_labelTrainIds.png")
18 |
19 | def image_path(root, basename, extension):
20 | return os.path.join(root, f'{basename}{extension}')
21 |
22 | def image_path_city(root, name):
23 | return os.path.join(root, f'{name}')
24 |
25 | def image_basename(filename):
26 | return os.path.basename(os.path.splitext(filename)[0])
27 |
28 | class VOC12(Dataset):
29 |
30 | def __init__(self, root, input_transform=None, target_transform=None):
31 | self.images_root = os.path.join(root, 'images')
32 | self.labels_root = os.path.join(root, 'labels')
33 |
34 | self.filenames = [image_basename(f)
35 | for f in os.listdir(self.labels_root) if is_image(f)]
36 | self.filenames.sort()
37 |
38 | self.input_transform = input_transform
39 | self.target_transform = target_transform
40 |
41 | def __getitem__(self, index):
42 | filename = self.filenames[index]
43 |
44 | with open(image_path(self.images_root, filename, '.jpg'), 'rb') as f:
45 | image = load_image(f).convert('RGB')
46 | with open(image_path(self.labels_root, filename, '.png'), 'rb') as f:
47 | label = load_image(f).convert('P')
48 |
49 | if self.input_transform is not None:
50 | image = self.input_transform(image)
51 | if self.target_transform is not None:
52 | label = self.target_transform(label)
53 |
54 | return image, label
55 |
56 | def __len__(self):
57 | return len(self.filenames)
58 |
59 |
60 | class cityscapes(Dataset):
61 |
62 | def __init__(self, root, input_transform=None, target_transform=None, subset='val'):
63 | self.images_root = os.path.join(root, 'leftImg8bit/' + subset)
64 | self.labels_root = os.path.join(root, 'gtFine/' + subset)
65 |
66 | self.filenames = [os.path.join(dp, f) for dp, dn, fn in os.walk(os.path.expanduser(self.images_root)) for f in fn if is_image(f)]
67 | self.filenames.sort()
68 |
69 | self.filenamesGt = [os.path.join(dp, f) for dp, dn, fn in os.walk(os.path.expanduser(self.labels_root)) for f in fn if is_label(f)]
70 | self.filenamesGt.sort()
71 |
72 | self.input_transform = input_transform
73 | self.target_transform = target_transform
74 |
75 | def __getitem__(self, index):
76 | filename = self.filenames[index]
77 | filenameGt = self.filenamesGt[index]
78 |
79 | #print(filename)
80 |
81 | with open(image_path_city(self.images_root, filename), 'rb') as f:
82 | image = load_image(f).convert('RGB')
83 | with open(image_path_city(self.labels_root, filenameGt), 'rb') as f:
84 | label = load_image(f).convert('P')
85 |
86 | if self.input_transform is not None:
87 | image = self.input_transform(image)
88 | if self.target_transform is not None:
89 | label = self.target_transform(label)
90 |
91 | return image, label, filename, filenameGt
92 |
93 | def __len__(self):
94 | return len(self.filenames)
95 |
96 |
--------------------------------------------------------------------------------
/test/eval_cityscapes_color.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import os
4 | import importlib
5 |
6 | from PIL import Image
7 | from argparse import ArgumentParser
8 |
9 | from torch.autograd import Variable
10 | from torch.utils.data import DataLoader
11 | from torchvision.transforms import Compose, CenterCrop, Normalize, Resize
12 | from torchvision.transforms import ToTensor, ToPILImage
13 |
14 | from dataset import cityscapes
15 |
16 |
17 |
18 | from lednet import Net
19 |
20 |
21 |
22 | from transform import Relabel, ToLabel, Colorize
23 |
24 | import visdom
25 |
26 |
27 | NUM_CHANNELS = 3
28 | NUM_CLASSES = 20
29 |
30 | image_transform = ToPILImage()
31 | input_transform_cityscapes = Compose([
32 | Resize((512,1024),Image.BILINEAR),
33 | ToTensor(),
34 | #Normalize([.485, .456, .406], [.229, .224, .225]),
35 | ])
36 | target_transform_cityscapes = Compose([
37 | Resize((512,1024),Image.NEAREST),
38 | ToLabel(),
39 | Relabel(255, 19), #ignore label to 19
40 | ])
41 |
42 | cityscapes_trainIds2labelIds = Compose([
43 | Relabel(19, 255),
44 | Relabel(18, 33),
45 | Relabel(17, 32),
46 | Relabel(16, 31),
47 | Relabel(15, 28),
48 | Relabel(14, 27),
49 | Relabel(13, 26),
50 | Relabel(12, 25),
51 | Relabel(11, 24),
52 | Relabel(10, 23),
53 | Relabel(9, 22),
54 | Relabel(8, 21),
55 | Relabel(7, 20),
56 | Relabel(6, 19),
57 | Relabel(5, 17),
58 | Relabel(4, 13),
59 | Relabel(3, 12),
60 | Relabel(2, 11),
61 | Relabel(1, 8),
62 | Relabel(0, 7),
63 | Relabel(255, 0),
64 | ToPILImage(),
65 | ])
66 |
67 | def main(args):
68 |
69 | modelpath = args.loadDir + args.loadModel
70 | weightspath = args.loadDir + args.loadWeights
71 |
72 | print ("Loading model: " + modelpath)
73 | print ("Loading weights: " + weightspath)
74 |
75 | model = Net(NUM_CLASSES)
76 |
77 | model = torch.nn.DataParallel(model)
78 | if (not args.cpu):
79 | model = model.cuda()
80 |
81 | #model.load_state_dict(torch.load(args.state))
82 | #model.load_state_dict(torch.load(weightspath)) #not working if missing key
83 |
84 | def load_my_state_dict(model, state_dict): #custom function to load model when not all dict elements
85 | own_state = model.state_dict()
86 | for name, param in state_dict.items():
87 | if name not in own_state:
88 | continue
89 | own_state[name].copy_(param)
90 | return model
91 |
92 | model = load_my_state_dict(model, torch.load(weightspath))
93 | print ("Model and weights LOADED successfully")
94 |
95 | model.eval()
96 |
97 | if(not os.path.exists(args.datadir)):
98 | print ("Error: datadir could not be loaded")
99 |
100 |
101 | loader = DataLoader(cityscapes(args.datadir, input_transform_cityscapes, target_transform_cityscapes, subset=args.subset),
102 | num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False)
103 |
104 | # For visualizer:
105 | # must launch in other window "python3.6 -m visdom.server -port 8097"
106 | # and access localhost:8097 to see it
107 | if (args.visualize):
108 | vis = visdom.Visdom()
109 |
110 | for step, (images, labels, filename, filenameGt) in enumerate(loader):
111 | if (not args.cpu):
112 | images = images.cuda()
113 | #labels = labels.cuda()
114 |
115 | inputs = Variable(images)
116 | #targets = Variable(labels)
117 | with torch.no_grad():
118 | outputs = model(inputs)
119 |
120 | label = outputs[0].max(0)[1].byte().cpu().data
121 | #label_cityscapes = cityscapes_trainIds2labelIds(label.unsqueeze(0))
122 | label_color = Colorize()(label.unsqueeze(0))
123 |
124 | filenameSave = "./save_color/" + filename[0].split("leftImg8bit/")[1]
125 | os.makedirs(os.path.dirname(filenameSave), exist_ok=True)
126 | #image_transform(label.byte()).save(filenameSave)
127 | label_save = ToPILImage()(label_color)
128 | label_save.save(filenameSave)
129 |
130 | if (args.visualize):
131 | vis.image(label_color.numpy())
132 | print (step, filenameSave)
133 |
134 |
135 |
136 | if __name__ == '__main__':
137 | parser = ArgumentParser()
138 |
139 | parser.add_argument('--state')
140 |
141 |
142 | parser.add_argument('--loadDir',default="../save/logs/")
143 | parser.add_argument('--loadWeights', default="model_best.pth")
144 | parser.add_argument('--loadModel', default="lednet.py")
145 | parser.add_argument('--subset', default="val") #can be val, test, train, demoSequence
146 |
147 | parser.add_argument('--datadir', default=os.getenv("HOME") + "/datasets/cityscapes/")
148 | parser.add_argument('--num-workers', type=int, default=4)
149 | parser.add_argument('--batch-size', type=int, default=1)
150 | parser.add_argument('--cpu', action='store_true')
151 |
152 | parser.add_argument('--visualize', action='store_true')
153 | main(parser.parse_args())
154 |
--------------------------------------------------------------------------------
/test/eval_cityscapes_server.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import os
4 | import importlib
5 |
6 | from PIL import Image
7 | from argparse import ArgumentParser
8 |
9 | from torch.autograd import Variable
10 | from torch.utils.data import DataLoader
11 | from torchvision.transforms import Compose, CenterCrop, Normalize, Resize
12 | from torchvision.transforms import ToTensor, ToPILImage
13 |
14 | from dataset import cityscapes
15 |
16 | from lednet import Net
17 |
18 |
19 | from transform import Relabel, ToLabel, Colorize
20 |
21 |
22 | NUM_CHANNELS = 3
23 | NUM_CLASSES = 20
24 |
25 | image_transform = ToPILImage()
26 | input_transform_cityscapes = Compose([
27 | Resize(512),
28 | ToTensor(),
29 | #Normalize([.485, .456, .406], [.229, .224, .225]),
30 | ])
31 | target_transform_cityscapes = Compose([
32 | Resize(512),
33 | ToLabel(),
34 | Relabel(255, 19), #ignore label to 19
35 | ])
36 |
37 | cityscapes_trainIds2labelIds = Compose([
38 | Relabel(19, 255),
39 | Relabel(18, 33),
40 | Relabel(17, 32),
41 | Relabel(16, 31),
42 | Relabel(15, 28),
43 | Relabel(14, 27),
44 | Relabel(13, 26),
45 | Relabel(12, 25),
46 | Relabel(11, 24),
47 | Relabel(10, 23),
48 | Relabel(9, 22),
49 | Relabel(8, 21),
50 | Relabel(7, 20),
51 | Relabel(6, 19),
52 | Relabel(5, 17),
53 | Relabel(4, 13),
54 | Relabel(3, 12),
55 | Relabel(2, 11),
56 | Relabel(1, 8),
57 | Relabel(0, 7),
58 | Relabel(255, 0),
59 | ToPILImage(),
60 | Resize(1024, Image.NEAREST),
61 | ])
62 |
63 | def main(args):
64 |
65 | modelpath = args.loadDir + args.loadModel
66 | weightspath = args.loadDir + args.loadWeights
67 |
68 | print ("Loading model: " + modelpath)
69 | print ("Loading weights: " + weightspath)
70 |
71 |
72 | model = Net(NUM_CLASSES)
73 |
74 | model = torch.nn.DataParallel(model)
75 | if (not args.cpu):
76 | model = model.cuda()
77 |
78 | #model.load_state_dict(torch.load(args.state))
79 | #model.load_state_dict(torch.load(weightspath)) #not working if missing key
80 |
81 | def load_my_state_dict(model, state_dict): #custom function to load model when not all dict elements
82 | own_state = model.state_dict()
83 | for name, param in state_dict.items():
84 | if name not in own_state:
85 | continue
86 | own_state[name].copy_(param)
87 | return model
88 |
89 | model = load_my_state_dict(model, torch.load(weightspath))
90 | print ("Model and weights LOADED successfully")
91 |
92 | model.eval()
93 |
94 | if(not os.path.exists(args.datadir)):
95 | print ("Error: datadir could not be loaded")
96 |
97 |
98 | loader = DataLoader(cityscapes(args.datadir, input_transform_cityscapes, target_transform_cityscapes, subset=args.subset),
99 | num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False)
100 |
101 | for step, (images, labels, filename, filenameGt) in enumerate(loader):
102 | if (not args.cpu):
103 | images = images.cuda()
104 | #labels = labels.cuda()
105 |
106 | inputs = Variable(images)
107 | #targets = Variable(labels)
108 | with torch.no_grad():
109 | outputs = model(inputs)
110 |
111 | label = outputs[0].max(0)[1].byte().cpu().data
112 | label_cityscapes = cityscapes_trainIds2labelIds(label.unsqueeze(0))
113 | #print (numpy.unique(label.numpy())) #debug
114 |
115 |
116 | filenameSave = "./save_results/" + filename[0].split("leftImg8bit/")[1]
117 |
118 | os.makedirs(os.path.dirname(filenameSave), exist_ok=True)
119 | #image_transform(label.byte()).save(filenameSave)
120 | label_cityscapes.save(filenameSave)
121 |
122 | print (step, filenameSave)
123 |
124 |
125 |
126 | if __name__ == '__main__':
127 | parser = ArgumentParser()
128 |
129 | parser.add_argument('--state')
130 |
131 | parser.add_argument('--loadDir',default="../save/logs/")
132 | parser.add_argument('--loadWeights', default="model_best.pth")
133 | parser.add_argument('--loadModel', default="lednet.py")
134 | parser.add_argument('--subset', default="val") #can be val, test, train, demoSequence
135 | parser.add_argument('--datadir', default=os.getenv("HOME") + "/datasets/cityscapes/")
136 | parser.add_argument('--num-workers', type=int, default=4)
137 | parser.add_argument('--batch-size', type=int, default=1)
138 | parser.add_argument('--cpu', action='store_true')
139 |
140 | main(parser.parse_args())
141 |
--------------------------------------------------------------------------------
/test/eval_forward_time.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | import time
5 |
6 | from PIL import Image
7 | from argparse import ArgumentParser
8 |
9 | from torch.autograd import Variable
10 |
11 |
12 | from lednet_no_bn import Net
13 |
14 | from transform import Relabel, ToLabel, Colorize
15 |
16 | import torch.backends.cudnn as cudnn
17 | cudnn.benchmark = True
18 |
19 | def main(args):
20 | model = Net(19)
21 |
22 | if (not args.cpu):
23 | model = model.cuda()#.half() #HALF seems to be doing slower for some reason
24 | #model = torch.nn.DataParallel(model).cuda()
25 |
26 | model.eval()
27 |
28 |
29 | images = torch.randn(args.batch_size, args.num_channels, args.height, args.width)
30 |
31 | if (not args.cpu):
32 | images = images.cuda()#.half()
33 |
34 | time_train = []
35 |
36 | i=0
37 |
38 | while(True):
39 | #for step, (images, labels, filename, filenameGt) in enumerate(loader):
40 |
41 | start_time = time.time()
42 |
43 | inputs = Variable(images)
44 | with torch.no_grad():
45 | outputs = model(inputs)
46 |
47 | #preds = outputs.cpu()
48 | if (not args.cpu):
49 | torch.cuda.synchronize() #wait for cuda to finish (cuda is asynchronous!)
50 |
51 | if i!=0: #first run always takes some time for setup
52 | fwt = time.time() - start_time
53 | time_train.append(fwt)
54 | print ("Forward time per img (b=%d): %.3f (Mean: %.3f)" % (args.batch_size, fwt/args.batch_size, sum(time_train) / len(time_train) / args.batch_size))
55 |
56 | time.sleep(1) #to avoid overheating the GPU too much
57 | i+=1
58 |
59 | if __name__ == '__main__':
60 | parser = ArgumentParser()
61 |
62 | parser.add_argument('--width', type=int, default=1024)
63 | parser.add_argument('--height', type=int, default=512)
64 | parser.add_argument('--num-channels', type=int, default=3)
65 | parser.add_argument('--batch-size', type=int, default=1)
66 | parser.add_argument('--cpu', action='store_true')
67 |
68 | main(parser.parse_args())
69 |
--------------------------------------------------------------------------------
/test/eval_iou.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | import os
5 | import importlib
6 | import time
7 |
8 | from PIL import Image
9 | from argparse import ArgumentParser
10 |
11 | from torch.autograd import Variable
12 | from torch.utils.data import DataLoader
13 | from torchvision.transforms import Compose, CenterCrop, Normalize, Resize
14 | from torchvision.transforms import ToTensor, ToPILImage
15 |
16 | from dataset import cityscapes
17 |
18 | from lednet import Net
19 |
20 |
21 | from transform import Relabel, ToLabel, Colorize
22 | from iouEval import iouEval, getColorEntry
23 |
24 | NUM_CHANNELS = 3
25 | NUM_CLASSES = 20
26 |
27 | image_transform = ToPILImage()
28 | input_transform_cityscapes = Compose([
29 | Resize(512, Image.BILINEAR),
30 | ToTensor(),
31 | ])
32 | target_transform_cityscapes = Compose([
33 | Resize(512, Image.NEAREST),
34 | ToLabel(),
35 | Relabel(255, 19), #ignore label to 19
36 | ])
37 |
38 | def main(args):
39 |
40 | modelpath = args.loadDir + args.loadModel
41 | weightspath = args.loadDir + args.loadWeights
42 |
43 | print ("Loading model: " + modelpath)
44 | print ("Loading weights: " + weightspath)
45 |
46 | model = Net(NUM_CLASSES)
47 |
48 | #model = torch.nn.DataParallel(model)
49 | if (not args.cpu):
50 | model = torch.nn.DataParallel(model).cuda()
51 |
52 | def load_my_state_dict(model, state_dict): #custom function to load model when not all dict elements
53 | own_state = model.state_dict()
54 | for name, param in state_dict.items():
55 | if name not in own_state:
56 | if name.startswith("module."):
57 | own_state[name.split("module.")[-1]].copy_(param)
58 | else:
59 | print(name, " not loaded")
60 | continue
61 | else:
62 | own_state[name].copy_(param)
63 | return model
64 |
65 | model = load_my_state_dict(model, torch.load(weightspath, map_location=lambda storage, loc: storage))
66 | print ("Model and weights LOADED successfully")
67 |
68 |
69 | model.eval()
70 |
71 | if(not os.path.exists(args.datadir)):
72 | print ("Error: datadir could not be loaded")
73 |
74 |
75 | loader = DataLoader(cityscapes(args.datadir, input_transform_cityscapes, target_transform_cityscapes, subset=args.subset), num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False)
76 |
77 |
78 | iouEvalVal = iouEval(NUM_CLASSES)
79 |
80 | start = time.time()
81 |
82 | for step, (images, labels, filename, filenameGt) in enumerate(loader):
83 | if (not args.cpu):
84 | images = images.cuda()
85 | labels = labels.cuda()
86 |
87 | inputs = Variable(images)
88 | with torch.no_grad():
89 | outputs = model(inputs)
90 |
91 | iouEvalVal.addBatch(outputs.max(1)[1].unsqueeze(1).data, labels)
92 |
93 | filenameSave = filename[0].split("leftImg8bit/")[1]
94 |
95 | print (step, filenameSave)
96 |
97 |
98 | iouVal, iou_classes = iouEvalVal.getIoU()
99 |
100 | iou_classes_str = []
101 | for i in range(iou_classes.size(0)):
102 | iouStr = getColorEntry(iou_classes[i])+'{:0.2f}'.format(iou_classes[i]*100) + '\033[0m'
103 | iou_classes_str.append(iouStr)
104 |
105 | print("---------------------------------------")
106 | print("Took ", time.time()-start, "seconds")
107 | print("=======================================")
108 | #print("TOTAL IOU: ", iou * 100, "%")
109 | print("Per-Class IoU:")
110 | print(iou_classes_str[0], "Road")
111 | print(iou_classes_str[1], "sidewalk")
112 | print(iou_classes_str[2], "building")
113 | print(iou_classes_str[3], "wall")
114 | print(iou_classes_str[4], "fence")
115 | print(iou_classes_str[5], "pole")
116 | print(iou_classes_str[6], "traffic light")
117 | print(iou_classes_str[7], "traffic sign")
118 | print(iou_classes_str[8], "vegetation")
119 | print(iou_classes_str[9], "terrain")
120 | print(iou_classes_str[10], "sky")
121 | print(iou_classes_str[11], "person")
122 | print(iou_classes_str[12], "rider")
123 | print(iou_classes_str[13], "car")
124 | print(iou_classes_str[14], "truck")
125 | print(iou_classes_str[15], "bus")
126 | print(iou_classes_str[16], "train")
127 | print(iou_classes_str[17], "motorcycle")
128 | print(iou_classes_str[18], "bicycle")
129 | print("=======================================")
130 | iouStr = getColorEntry(iouVal)+'{:0.2f}'.format(iouVal*100) + '\033[0m'
131 | print ("MEAN IoU: ", iouStr, "%")
132 |
133 | if __name__ == '__main__':
134 | parser = ArgumentParser()
135 |
136 | parser.add_argument('--state')
137 |
138 | parser.add_argument('--loadDir',default="../trained_models/")
139 | parser.add_argument('--loadWeights', default="lednet_trained.pth")
140 | parser.add_argument('--loadModel', default="lednet.py")
141 | parser.add_argument('--subset', default="val") #can be val or train (must have labels)
142 | parser.add_argument('--datadir', default=os.getenv("HOME") + "/datasets/cityscapes/")
143 | parser.add_argument('--num-workers', type=int, default=4)
144 | parser.add_argument('--batch-size', type=int, default=1)
145 | parser.add_argument('--cpu', action='store_true')
146 |
147 | main(parser.parse_args())
148 |
--------------------------------------------------------------------------------
/test/iouEval.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | class iouEval:
4 |
5 | def __init__(self, nClasses, ignoreIndex=19):
6 | self.nClasses = nClasses
7 | self.ignoreIndex = ignoreIndex if nClasses>ignoreIndex else -1 #if ignoreIndex is larger than nClasses, consider no ignoreIndex
8 | self.reset()
9 |
10 | def reset (self):
11 | classes = self.nClasses if self.ignoreIndex==-1 else self.nClasses-1
12 | self.tp = torch.zeros(classes).double()
13 | self.fp = torch.zeros(classes).double()
14 | self.fn = torch.zeros(classes).double()
15 |
16 | def addBatch(self, x, y): #x=preds, y=targets
17 | #sizes should be "batch_size x nClasses x H x W"
18 |
19 | #print ("X is cuda: ", x.is_cuda)
20 | #print ("Y is cuda: ", y.is_cuda)
21 |
22 | if (x.is_cuda or y.is_cuda):
23 | x = x.cuda()
24 | y = y.cuda()
25 |
26 | #if size is "batch_size x 1 x H x W" scatter to onehot
27 | if (x.size(1) == 1):
28 | x_onehot = torch.zeros(x.size(0), self.nClasses, x.size(2), x.size(3))
29 | if x.is_cuda:
30 | x_onehot = x_onehot.cuda()
31 | x_onehot.scatter_(1, x, 1).float()
32 | else:
33 | x_onehot = x.float()
34 |
35 | if (y.size(1) == 1):
36 | y_onehot = torch.zeros(y.size(0), self.nClasses, y.size(2), y.size(3))
37 | if y.is_cuda:
38 | y_onehot = y_onehot.cuda()
39 | y_onehot.scatter_(1, y, 1).float()
40 | else:
41 | y_onehot = y.float()
42 |
43 | if (self.ignoreIndex != -1):
44 | ignores = y_onehot[:,self.ignoreIndex].unsqueeze(1)
45 | x_onehot = x_onehot[:, :self.ignoreIndex]
46 | y_onehot = y_onehot[:, :self.ignoreIndex]
47 | else:
48 | ignores=0
49 |
50 | #print(type(x_onehot))
51 | #print(type(y_onehot))
52 | #print(x_onehot.size())
53 | #print(y_onehot.size())
54 |
55 | tpmult = x_onehot * y_onehot #positions where prediction and ground truth coincide (true positives)
56 | tp = torch.sum(torch.sum(torch.sum(tpmult, dim=0, keepdim=True), dim=2, keepdim=True), dim=3, keepdim=True).squeeze()
57 | fpmult = x_onehot * (1-y_onehot-ignores) #positions where the prediction says it is that class but the ground truth says it is not (excluding ignore-label cases)
58 | fp = torch.sum(torch.sum(torch.sum(fpmult, dim=0, keepdim=True), dim=2, keepdim=True), dim=3, keepdim=True).squeeze()
59 | fnmult = (1-x_onehot) * (y_onehot) #positions where the prediction says it is not that class but the ground truth says it is
60 | fn = torch.sum(torch.sum(torch.sum(fnmult, dim=0, keepdim=True), dim=2, keepdim=True), dim=3, keepdim=True).squeeze()
61 |
62 | self.tp += tp.double().cpu()
63 | self.fp += fp.double().cpu()
64 | self.fn += fn.double().cpu()
65 |
66 | def getIoU(self):
67 | num = self.tp
68 | den = self.tp + self.fp + self.fn + 1e-15
69 | iou = num / den
70 | return torch.mean(iou), iou #returns "iou mean", "iou per class"
71 |
72 | # Class for colors
73 | class colors:
74 | RED = '\033[31;1m'
75 | GREEN = '\033[32;1m'
76 | YELLOW = '\033[33;1m'
77 | BLUE = '\033[34;1m'
78 | MAGENTA = '\033[35;1m'
79 | CYAN = '\033[36;1m'
80 | BOLD = '\033[1m'
81 | UNDERLINE = '\033[4m'
82 | ENDC = '\033[0m'
83 |
84 | # Colored value output if colorized flag is activated.
85 | def getColorEntry(val):
86 | if not isinstance(val, float):
87 | return colors.ENDC
88 | if (val < .20):
89 | return colors.RED
90 | elif (val < .40):
91 | return colors.YELLOW
92 | elif (val < .60):
93 | return colors.BLUE
94 | elif (val < .80):
95 | return colors.CYAN
96 | else:
97 | return colors.GREEN
98 |
99 |
--------------------------------------------------------------------------------
/test/lednet_no_bn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.init as init
4 | import torch.nn.functional as F
5 | from torch.nn.functional import interpolate as interpolate
6 |
7 | def split(x):
8 | c = int(x.size()[1])
9 | c1 = round(c * 0.5)
10 | x1 = x[:, :c1, :, :].contiguous()
11 | x2 = x[:, c1:, :, :].contiguous()
12 |
13 | return x1, x2
14 |
15 | def channel_shuffle(x,groups):
16 | batchsize, num_channels, height, width = x.data.size()
17 |
18 | channels_per_group = num_channels // groups
19 |
20 | #reshape
21 | x = x.view(batchsize,groups,
22 | channels_per_group,height,width)
23 |
24 | x = torch.transpose(x,1,2).contiguous()
25 |
26 | #flatten
27 | x = x.view(batchsize,-1,height,width)
28 |
29 | return x
30 |
31 |
32 | class Conv2dBnRelu(nn.Module):
33 | def __init__(self,in_ch,out_ch,kernel_size=3,stride=1,padding=0,dilation=1,bias=True):
34 | super(Conv2dBnRelu,self).__init__()
35 |
36 | self.conv = nn.Sequential(
37 | nn.Conv2d(in_ch,out_ch,kernel_size,stride,padding,dilation=dilation,bias=bias),
38 | # nn.BatchNorm2d(out_ch, eps=1e-3),
39 | nn.ReLU(inplace=True)
40 | )
41 |
42 | def forward(self, x):
43 | return self.conv(x)
44 |
45 |
46 |
47 | class DownsamplerBlock (nn.Module):
48 | def __init__(self, in_channel, out_channel):
49 | super(DownsamplerBlock,self).__init__()
50 |
51 | self.conv = nn.Conv2d(in_channel, out_channel-in_channel, (3, 3), stride=2, padding=1, bias=True)
52 | self.pool = nn.MaxPool2d(2, stride=2)
53 | self.bn = nn.BatchNorm2d(out_channel, eps=1e-3)
54 | self.relu = nn.ReLU(inplace=True)
55 |
56 | def forward(self, input):
57 | x1 = self.pool(input)
58 | x2 = self.conv(input)
59 |
60 | diffY = x2.size()[2] - x1.size()[2]
61 | diffX = x2.size()[3] - x1.size()[3]
62 |
63 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
64 | diffY // 2, diffY - diffY // 2])
65 |
66 | output = torch.cat([x2, x1], 1)
67 | output = self.bn(output)
68 | output = self.relu(output)
69 | return output
70 |
71 |
72 | class SS_nbt_module(nn.Module):
73 | def __init__(self, chann, dropprob, dilated):
74 | super(SS_nbt_module,self).__init__()
75 |
76 | oup_inc = chann//2
77 |
78 | #
79 | self.conv3x1_1_l = nn.Conv2d(oup_inc, oup_inc, (3,1), stride=1, padding=(1,0), bias=True)
80 |
81 | self.conv1x3_1_l = nn.Conv2d(oup_inc, oup_inc, (1,3), stride=1, padding=(0,1), bias=True)
82 |
83 | self.bn1_l = nn.BatchNorm2d(oup_inc, eps=1e-03)
84 |
85 | self.conv3x1_2_l = nn.Conv2d(oup_inc, oup_inc, (3,1), stride=1, padding=(1*dilated,0), bias=True, dilation = (dilated,1))
86 |
87 | self.conv1x3_2_l = nn.Conv2d(oup_inc, oup_inc, (1,3), stride=1, padding=(0,1*dilated), bias=True, dilation = (1,dilated))
88 |
89 | self.bn2_l = nn.BatchNorm2d(oup_inc, eps=1e-03)
90 |
91 | #
92 | self.conv3x1_1_r = nn.Conv2d(oup_inc, oup_inc, (3,1), stride=1, padding=(1,0), bias=True)
93 |
94 | self.conv1x3_1_r = nn.Conv2d(oup_inc, oup_inc, (1,3), stride=1, padding=(0,1), bias=True)
95 |
96 | self.bn1_r = nn.BatchNorm2d(oup_inc, eps=1e-03)
97 |
98 | self.conv3x1_2_r = nn.Conv2d(oup_inc, oup_inc, (3,1), stride=1, padding=(1*dilated,0), bias=True, dilation = (dilated,1))
99 |
100 | self.conv1x3_2_r = nn.Conv2d(oup_inc, oup_inc, (1,3), stride=1, padding=(0,1*dilated), bias=True, dilation = (1,dilated))
101 |
102 | self.bn2_r = nn.BatchNorm2d(oup_inc, eps=1e-03)
103 |
104 | self.relu = nn.ReLU(inplace=True)
105 | self.dropout = nn.Dropout2d(dropprob)
106 |
107 | @staticmethod
108 | def _concat(x,out):
109 | return torch.cat((x,out),1)
110 |
111 | def forward(self, input):
112 |
113 | # x1 = input[:,:(input.shape[1]//2),:,:]
114 | # x2 = input[:,(input.shape[1]//2):,:,:]
115 | residual = input
116 | x1, x2 = split(input)
117 |
118 | output1 = self.conv3x1_1_l(x1)
119 | output1 = self.relu(output1)
120 | output1 = self.conv1x3_1_l(output1)
121 | # output1 = self.bn1_l(output1)
122 | output1 = self.relu(output1)
123 |
124 | output1 = self.conv3x1_2_l(output1)
125 | output1 = self.relu(output1)
126 | output1 = self.conv1x3_2_l(output1)
127 | # output1 = self.bn2_l(output1)
128 |
129 |
130 | output2 = self.conv1x3_1_r(x2)
131 | output2 = self.relu(output2)
132 | output2 = self.conv3x1_1_r(output2)
133 | # output2 = self.bn1_r(output2)
134 | output2 = self.relu(output2)
135 |
136 | output2 = self.conv1x3_2_r(output2)
137 | output2 = self.relu(output2)
138 | output2 = self.conv3x1_2_r(output2)
139 | # output2 = self.bn2_r(output2)
140 |
141 | # if (self.dropout.p != 0):
142 | #output1 = self.dropout(output1)
143 | #output2 = self.dropout(output2)
144 |
145 | out = self._concat(output1,output2)
146 | out = F.relu(residual + out, inplace=True)
147 | return channel_shuffle(out, 2)
148 |
149 | class Encoder(nn.Module):
150 | def __init__(self, num_classes):
151 | super().__init__()
152 | self.initial_block = DownsamplerBlock(3,32)
153 |
154 |
155 | self.layers = nn.ModuleList()
156 |
157 | for x in range(0, 3):
158 | self.layers.append(SS_nbt_module(32, 0.03, 1))
159 |
160 |
161 | self.layers.append(DownsamplerBlock(32,64))
162 |
163 |
164 | for x in range(0, 2):
165 | self.layers.append(SS_nbt_module(64, 0.03, 1))
166 |
167 | self.layers.append(DownsamplerBlock(64,128))
168 |
169 | for x in range(0, 1):
170 | self.layers.append(SS_nbt_module(128, 0.3, 1))
171 | self.layers.append(SS_nbt_module(128, 0.3, 2))
172 | self.layers.append(SS_nbt_module(128, 0.3, 5))
173 | self.layers.append(SS_nbt_module(128, 0.3, 9))
174 |
175 | for x in range(0, 1):
176 | self.layers.append(SS_nbt_module(128, 0.3, 2))
177 | self.layers.append(SS_nbt_module(128, 0.3, 5))
178 | self.layers.append(SS_nbt_module(128, 0.3, 9))
179 | self.layers.append(SS_nbt_module(128, 0.3, 17))
180 |
181 |
182 | # Only in encoder mode:
183 | self.output_conv = nn.Conv2d(128, num_classes, 1, stride=1, padding=0, bias=True)
184 |
185 | def forward(self, input, predict=False):
186 |
187 | output = self.initial_block(input)
188 |
189 | for layer in self.layers:
190 | output = layer(output)
191 |
192 | if predict:
193 | output = self.output_conv(output)
194 |
195 | return output
196 |
197 | class Interpolate(nn.Module):
198 | def __init__(self,size,mode):
199 | super(Interpolate,self).__init__()
200 |
201 | self.interp = nn.functional.interpolate
202 | self.size = size
203 | self.mode = mode
204 | def forward(self,x):
205 | x = self.interp(x,size=self.size,mode=self.mode,align_corners=True)
206 | return x
207 |
208 |
209 | class APN_Module(nn.Module):
210 | def __init__(self, in_ch, out_ch):
211 | super(APN_Module, self).__init__()
212 | # global pooling branch
213 | self.branch1 = nn.Sequential(
214 | nn.AdaptiveAvgPool2d(1),
215 | Conv2dBnRelu(in_ch, out_ch, kernel_size=1, stride=1, padding=0)
216 | )
217 |
218 | # middle branch
219 | self.mid = nn.Sequential(
220 | Conv2dBnRelu(in_ch, out_ch, kernel_size=1, stride=1, padding=0)
221 | )
222 |
223 | self.down1 = Conv2dBnRelu(in_ch, 1, kernel_size=7, stride=2, padding=3)
224 |
225 | self.down2 = Conv2dBnRelu(1, 1, kernel_size=5, stride=2, padding=2)
226 |
227 | self.down3 = nn.Sequential(
228 | Conv2dBnRelu(1, 1, kernel_size=3, stride=2, padding=1),
229 | Conv2dBnRelu(1, 1, kernel_size=3, stride=1, padding=1)
230 | )
231 |
232 | self.conv2 = Conv2dBnRelu(1, 1, kernel_size=5, stride=1, padding=2)
233 | self.conv1 = Conv2dBnRelu(1, 1, kernel_size=7, stride=1, padding=3)
234 |
235 | def forward(self, x):
236 |
237 | h = x.size()[2]
238 | w = x.size()[3]
239 |
240 | b1 = self.branch1(x)
241 | # b1 = Interpolate(size=(h, w), mode="bilinear")(b1)
242 | b1= interpolate(b1, size=(h, w), mode="bilinear", align_corners=True)
243 |
244 | mid = self.mid(x)
245 |
246 | x1 = self.down1(x)
247 | x2 = self.down2(x1)
248 | x3 = self.down3(x2)
249 | # x3 = Interpolate(size=(h // 4, w // 4), mode="bilinear")(x3)
250 | x3= interpolate(x3, size=(h // 4, w // 4), mode="bilinear", align_corners=True)
251 | x2 = self.conv2(x2)
252 | x = x2 + x3
253 | # x = Interpolate(size=(h // 2, w // 2), mode="bilinear")(x)
254 | x= interpolate(x, size=(h // 2, w // 2), mode="bilinear", align_corners=True)
255 |
256 | x1 = self.conv1(x1)
257 | x = x + x1
258 | # x = Interpolate(size=(h, w), mode="bilinear")(x)
259 | x= interpolate(x, size=(h, w), mode="bilinear", align_corners=True)
260 |
261 | x = torch.mul(x, mid)
262 |
263 | x = x + b1
264 |
265 |
266 | return x
267 |
268 |
269 |
270 |
271 | class Decoder (nn.Module):
272 | def __init__(self, num_classes):
273 | super().__init__()
274 |
275 | self.apn = APN_Module(in_ch=128,out_ch=20)
276 | # self.upsample = Interpolate(size=(512, 1024), mode="bilinear")
277 | # self.output_conv = nn.ConvTranspose2d(16, num_classes, kernel_size=4, stride=2, padding=1, output_padding=0, bias=True)
278 | # self.output_conv = nn.ConvTranspose2d(16, num_classes, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True)
279 | # self.output_conv = nn.ConvTranspose2d(16, num_classes, kernel_size=2, stride=2, padding=0, output_padding=0, bias=True)
280 |
281 | def forward(self, input):
282 |
283 | output = self.apn(input)
284 | out = interpolate(output, size=(512, 1024), mode="bilinear", align_corners=True)
285 | # out = self.upsample(output)
286 | return out
287 |
288 |
289 | # LEDNet
290 | class Net(nn.Module):
291 | def __init__(self, num_classes, encoder=None):
292 | super().__init__()
293 |
294 | if (encoder == None):
295 | self.encoder = Encoder(num_classes)
296 | else:
297 | self.encoder = encoder
298 | self.decoder = Decoder(num_classes)
299 |
300 | def forward(self, input, only_encode=False):
301 | if only_encode:
302 | return self.encoder.forward(input, predict=True)
303 | else:
304 | output = self.encoder(input)
305 | return self.decoder.forward(output)
306 |
307 |
308 |
--------------------------------------------------------------------------------
/test/transform.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | from PIL import Image
5 |
6 | def colormap_cityscapes(n):
7 | cmap=np.zeros([n, 3]).astype(np.uint8)
8 | cmap[0,:] = np.array([128, 64,128])
9 | cmap[1,:] = np.array([244, 35,232])
10 | cmap[2,:] = np.array([ 70, 70, 70])
11 | cmap[3,:] = np.array([ 102,102,156])
12 | cmap[4,:] = np.array([ 190,153,153])
13 | cmap[5,:] = np.array([ 153,153,153])
14 |
15 | cmap[6,:] = np.array([ 250,170, 30])
16 | cmap[7,:] = np.array([ 220,220, 0])
17 | cmap[8,:] = np.array([ 107,142, 35])
18 | cmap[9,:] = np.array([ 152,251,152])
19 | cmap[10,:] = np.array([ 70,130,180])
20 |
21 | cmap[11,:] = np.array([ 220, 20, 60])
22 | cmap[12,:] = np.array([ 255, 0, 0])
23 | cmap[13,:] = np.array([ 0, 0,142])
24 | cmap[14,:] = np.array([ 0, 0, 70])
25 | cmap[15,:] = np.array([ 0, 60,100])
26 |
27 | cmap[16,:] = np.array([ 0, 80,100])
28 | cmap[17,:] = np.array([ 0, 0,230])
29 | cmap[18,:] = np.array([ 119, 11, 32])
30 | cmap[19,:] = np.array([ 0, 0, 0])
31 |
32 | return cmap
33 |
34 |
35 | def colormap(n):
36 | cmap=np.zeros([n, 3]).astype(np.uint8)
37 |
38 | for i in np.arange(n):
39 | r, g, b = np.zeros(3)
40 |
41 | for j in np.arange(8):
42 | r = r + (1<<(7-j))*((i&(1<<(3*j))) >> (3*j))
43 | g = g + (1<<(7-j))*((i&(1<<(3*j+1))) >> (3*j+1))
44 | b = b + (1<<(7-j))*((i&(1<<(3*j+2))) >> (3*j+2))
45 |
46 | cmap[i,:] = np.array([r, g, b])
47 |
48 | return cmap
49 |
50 | class Relabel:
51 |
52 | def __init__(self, olabel, nlabel):
53 | self.olabel = olabel
54 | self.nlabel = nlabel
55 |
56 | def __call__(self, tensor):
57 | assert isinstance(tensor, torch.LongTensor) or isinstance(tensor, torch.ByteTensor), 'tensor needs to be LongTensor or ByteTensor'
58 | tensor[tensor == self.olabel] = self.nlabel
59 | return tensor
60 |
61 |
62 | class ToLabel:
63 |
64 | def __call__(self, image):
65 | return torch.from_numpy(np.array(image)).long().unsqueeze(0)
66 |
67 |
68 | class Colorize:
69 |
70 | def __init__(self, n=22):
71 | #self.cmap = colormap(256)
72 | self.cmap = colormap_cityscapes(256)
73 | self.cmap[n] = self.cmap[-1]
74 | self.cmap = torch.from_numpy(self.cmap[:n])
75 |
76 | def __call__(self, gray_image):
77 | size = gray_image.size()
78 | color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0)
79 |
80 | #for label in range(1, len(self.cmap)):
81 | for label in range(0, len(self.cmap)):
82 | mask = gray_image[0] == label
83 |
84 | color_image[0][mask] = self.cmap[label][0]
85 | color_image[1][mask] = self.cmap[label][1]
86 | color_image[2][mask] = self.cmap[label][2]
87 |
88 | return color_image
89 |
--------------------------------------------------------------------------------
/train/README.md:
--------------------------------------------------------------------------------
1 | #### Training LEDNet in Pytorch
2 |
3 | PyTorch code for training the LEDNet model on Cityscapes. The code is initially based on [bodokaiser/piwise](https://github.com/bodokaiser/piwise) and has been adapted with several custom modifications and tweaks, including:
4 | - Load cityscapes dataset
5 | - LEDNet model definition
6 | - Calculate IoU on each epoch during training
7 | - Save snapshots and best model during training
8 | - Save additional output files useful for checking results (see below "Output files...")
9 | - Resume training from checkpoint (use the "--resume" flag in the command; see the example below)
10 |
11 | #### Options
12 | For all options and their defaults, please see the bottom of the "main.py" file. The required options are --savedir (name of the new folder that will hold all training outputs) and --datadir (path to the cityscapes directory).
13 |
14 | #### Example commands
15 | Train the encoder (e.g. for 300 epochs with batch size 5) and then train the decoder (decoder training starts automatically after encoder training finishes), for example:
16 | ```
17 | python main.py --savedir logs --datadir /home/datasets/cityscapes/ --num-epochs 300 --batch-size 5 ...
18 | ```
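
To resume an interrupted run from its last checkpoint, re-issue the same command with the "--resume" flag added (a minimal sketch, assuming the same --savedir and --datadir as above):
```
python main.py --savedir logs --datadir /home/datasets/cityscapes/ --num-epochs 300 --batch-size 5 --resume
```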
19 |
20 | Each training run creates a new folder in the "LEDNet_master/save/" directory, named after the --savedir parameter, containing the following files:
21 | * **{model}.py**: Copy of the model file used (default: lednet.py).
22 | * **model.txt**: Plain text file that displays the model's layers.
23 | * **model_best.pth**: Saved weights of the epoch that achieved the best val accuracy.
24 | * **model_best.pth.tar**: Same contents as "checkpoint.pth.tar", but for the epoch with the best val accuracy.
25 | * **opts.txt**: Plain text file containing the options used for this training.
26 | * **automated_log.txt**: Plain text file that contains, in columns, the following info for each epoch: {Epoch, Train-loss, Test-loss, Train-IoU, Test-IoU, learningRate}. It can be plotted with Gnuplot, Excel, or Matplotlib.
27 | * **best.txt**: Plain text file containing a line with the best IoU achieved during training and its epoch.
28 | * **checkpoint.pth.tar**: Bundle file containing the checkpoint of the last trained epoch, with the following elements: 'epoch' (epoch number as int), 'arch' (net definition as a string), 'state_dict' (saved weights dictionary loadable by PyTorch), 'best_acc' (best achieved accuracy as float), 'optimizer' (saved optimizer parameters). See the loading sketch below.
29 |
30 | NOTE: Encoder trainings have an added "_encoder" tag to each file's name.
31 |
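For reference, a minimal sketch of loading the "checkpoint.pth.tar" bundle to inspect or resume a run (the keys are the ones saved by "main.py"; the file path and the pre-built `model`/`optimizer` objects are assumptions for illustration):
```
import torch

checkpoint = torch.load('../save/logs/checkpoint.pth.tar', map_location='cpu')  # example path
print(checkpoint['epoch'], checkpoint['best_acc'])    # last epoch number and best val accuracy so far

# Restore weights and optimizer state. If the checkpoint was saved from a DataParallel
# model ('module.' key prefixes), wrap the model in DataParallel before loading.
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
```
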
32 | #### IoU display during training
33 |
34 | NEW: In previous code, IoU was calculated using a port of the cityscapes scripts, but new code has been added in "iouEval.py" to make it class-general, independent of other code, and much faster (using CUDA).
35 |
36 | By default, only Validation IoU is calculated, for faster training (this can be changed in the options).
37 |
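A minimal usage sketch of the evaluator (mirroring the calls made in "main.py"; `model` and `loader_val` are assumed to be set up as in the training script):
```
from utils.iouEval import iouEval   # in test/ it is imported as "from iouEval import iouEval"

evaluator = iouEval(20)             # 20 classes for Cityscapes; class 19 is the ignore label
for images, labels in loader_val:   # labels: batch x 1 x H x W label maps (move both to GPU first if using CUDA)
    outputs = model(images)         # logits: batch x num_classes x H x W
    evaluator.addBatch(outputs.max(1)[1].unsqueeze(1).data, labels)
mean_iou, per_class_iou = evaluator.getIoU()
print("MEAN IoU: {:0.2f} %".format(float(mean_iou) * 100))
```
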
38 | #### Visualization
39 | If you want to visualize the outputs during training, add the "--visualize" flag and start a Visdom server in an extra terminal tab with:
40 | ```
41 | python -m visdom.server -port 8097
42 | ```
43 | The plots will then be available in your browser at http://localhost:8097
44 |
45 | #### Multi-GPU
46 | If you wish to specify which GPUs to use, set the CUDA_VISIBLE_DEVICES environment variable:
47 | ```
48 | CUDA_VISIBLE_DEVICES=0 python main.py ...
49 | CUDA_VISIBLE_DEVICES=0,1 python main.py ...
50 | ```
51 |
52 |
53 |
--------------------------------------------------------------------------------
/train/lednet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.init as init
4 | import torch.nn.functional as F
5 | from torch.nn.functional import interpolate as interpolate
6 |
7 |
8 | def split(x):
9 | c = int(x.size()[1])
10 | c1 = round(c * 0.5)
11 | x1 = x[:, :c1, :, :].contiguous()
12 | x2 = x[:, c1:, :, :].contiguous()
13 |
14 | return x1, x2
15 |
16 | def channel_shuffle(x,groups):
17 | batchsize, num_channels, height, width = x.data.size()
18 |
19 | channels_per_group = num_channels // groups
20 |
21 | # reshape
22 | x = x.view(batchsize,groups,
23 | channels_per_group,height,width)
24 |
25 | x = torch.transpose(x,1,2).contiguous()
26 |
27 | # flatten
28 | x = x.view(batchsize,-1,height,width)
29 |
30 | return x
31 |
32 |
33 | class Conv2dBnRelu(nn.Module):
34 | def __init__(self,in_ch,out_ch,kernel_size=3,stride=1,padding=0,dilation=1,bias=True):
35 | super(Conv2dBnRelu,self).__init__()
36 |
37 | self.conv = nn.Sequential(
38 | nn.Conv2d(in_ch,out_ch,kernel_size,stride,padding,dilation=dilation,bias=bias),
39 | nn.BatchNorm2d(out_ch, eps=1e-3),
40 | nn.ReLU(inplace=True)
41 | )
42 |
43 | def forward(self, x):
44 | return self.conv(x)
45 |
46 |
47 | # After the Concat -> BN, you can also add Dropout (as in SS_nbt_module); it may give better results.
48 | class DownsamplerBlock (nn.Module):
49 | def __init__(self, in_channel, out_channel):
50 | super(DownsamplerBlock,self).__init__()
51 |
52 | self.conv = nn.Conv2d(in_channel, out_channel-in_channel, (3, 3), stride=2, padding=1, bias=True)
53 | self.pool = nn.MaxPool2d(2, stride=2)
54 | self.bn = nn.BatchNorm2d(out_channel, eps=1e-3)
55 | self.relu = nn.ReLU(inplace=True)
56 |
57 | def forward(self, input):
58 | x1 = self.pool(input)
59 | x2 = self.conv(input)
60 |
61 | diffY = x2.size()[2] - x1.size()[2]
62 | diffX = x2.size()[3] - x1.size()[3]
63 |
64 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
65 | diffY // 2, diffY - diffY // 2])
66 |
67 | output = torch.cat([x2, x1], 1)
68 | output = self.bn(output)
69 | output = self.relu(output)
70 | return output
71 |
72 |
73 | class SS_nbt_module(nn.Module):
74 | def __init__(self, chann, dropprob, dilated):
75 | super().__init__()
76 |
77 | oup_inc = chann//2
78 |
79 | # dw
80 | self.conv3x1_1_l = nn.Conv2d(oup_inc, oup_inc, (3,1), stride=1, padding=(1,0), bias=True)
81 |
82 | self.conv1x3_1_l = nn.Conv2d(oup_inc, oup_inc, (1,3), stride=1, padding=(0,1), bias=True)
83 |
84 | self.bn1_l = nn.BatchNorm2d(oup_inc, eps=1e-03)
85 |
86 | self.conv3x1_2_l = nn.Conv2d(oup_inc, oup_inc, (3,1), stride=1, padding=(1*dilated,0), bias=True, dilation = (dilated,1))
87 |
88 | self.conv1x3_2_l = nn.Conv2d(oup_inc, oup_inc, (1,3), stride=1, padding=(0,1*dilated), bias=True, dilation = (1,dilated))
89 |
90 | self.bn2_l = nn.BatchNorm2d(oup_inc, eps=1e-03)
91 |
92 | # dw
93 | self.conv3x1_1_r = nn.Conv2d(oup_inc, oup_inc, (3,1), stride=1, padding=(1,0), bias=True)
94 |
95 | self.conv1x3_1_r = nn.Conv2d(oup_inc, oup_inc, (1,3), stride=1, padding=(0,1), bias=True)
96 |
97 | self.bn1_r = nn.BatchNorm2d(oup_inc, eps=1e-03)
98 |
99 | self.conv3x1_2_r = nn.Conv2d(oup_inc, oup_inc, (3,1), stride=1, padding=(1*dilated,0), bias=True, dilation = (dilated,1))
100 |
101 | self.conv1x3_2_r = nn.Conv2d(oup_inc, oup_inc, (1,3), stride=1, padding=(0,1*dilated), bias=True, dilation = (1,dilated))
102 |
103 | self.bn2_r = nn.BatchNorm2d(oup_inc, eps=1e-03)
104 |
105 | self.relu = nn.ReLU(inplace=True)
106 | self.dropout = nn.Dropout2d(dropprob)
107 |
108 | @staticmethod
109 | def _concat(x,out):
110 | return torch.cat((x,out),1)
111 |
112 | def forward(self, input):
113 |
114 | # x1 = input[:,:(input.shape[1]//2),:,:]
115 | # x2 = input[:,(input.shape[1]//2):,:,:]
116 | residual = input
117 | x1, x2 = split(input)
118 |
119 | output1 = self.conv3x1_1_l(x1)
120 | output1 = self.relu(output1)
121 | output1 = self.conv1x3_1_l(output1)
122 | output1 = self.bn1_l(output1)
123 | output1 = self.relu(output1)
124 |
125 | output1 = self.conv3x1_2_l(output1)
126 | output1 = self.relu(output1)
127 | output1 = self.conv1x3_2_l(output1)
128 | output1 = self.bn2_l(output1)
129 |
130 |
131 | output2 = self.conv1x3_1_r(x2)
132 | output2 = self.relu(output2)
133 | output2 = self.conv3x1_1_r(output2)
134 | output2 = self.bn1_r(output2)
135 | output2 = self.relu(output2)
136 |
137 | output2 = self.conv1x3_2_r(output2)
138 | output2 = self.relu(output2)
139 | output2 = self.conv3x1_2_r(output2)
140 | output2 = self.bn2_r(output2)
141 |
142 | if (self.dropout.p != 0):
143 | output1 = self.dropout(output1)
144 | output2 = self.dropout(output2)
145 |
146 | out = self._concat(output1,output2)
147 | out = F.relu(residual + out, inplace=True)
148 | return channel_shuffle(out,2)
149 |
150 |
151 |
152 | class Encoder(nn.Module):
153 | def __init__(self, num_classes):
154 | super().__init__()
155 |
156 | self.initial_block = DownsamplerBlock(3,32)
157 |
158 | self.layers = nn.ModuleList()
159 |
160 | for x in range(0, 3):
161 | self.layers.append(SS_nbt_module(32, 0.03, 1))
162 |
163 |
164 | self.layers.append(DownsamplerBlock(32,64))
165 |
166 |
167 | for x in range(0, 2):
168 | self.layers.append(SS_nbt_module(64, 0.03, 1))
169 |
170 | self.layers.append(DownsamplerBlock(64,128))
171 |
172 | for x in range(0, 1):
173 | self.layers.append(SS_nbt_module(128, 0.3, 1))
174 | self.layers.append(SS_nbt_module(128, 0.3, 2))
175 | self.layers.append(SS_nbt_module(128, 0.3, 5))
176 | self.layers.append(SS_nbt_module(128, 0.3, 9))
177 |
178 | for x in range(0, 1):
179 | self.layers.append(SS_nbt_module(128, 0.3, 2))
180 | self.layers.append(SS_nbt_module(128, 0.3, 5))
181 | self.layers.append(SS_nbt_module(128, 0.3, 9))
182 | self.layers.append(SS_nbt_module(128, 0.3, 17))
183 |
184 |
185 | # Only in encoder mode:
186 | self.output_conv = nn.Conv2d(128, num_classes, 1, stride=1, padding=0, bias=True)
187 |
188 | def forward(self, input, predict=False):
189 |
190 | output = self.initial_block(input)
191 |
192 | for layer in self.layers:
193 | output = layer(output)
194 |
195 | if predict:
196 | output = self.output_conv(output)
197 |
198 | return output
199 |
200 | class Interpolate(nn.Module):
201 | def __init__(self,size,mode):
202 | super(Interpolate,self).__init__()
203 |
204 | self.interp = nn.functional.interpolate
205 | self.size = size
206 | self.mode = mode
207 | def forward(self,x):
208 | x = self.interp(x,size=self.size,mode=self.mode,align_corners=True)
209 | return x
210 |
211 |
212 | class APN_Module(nn.Module):
213 | def __init__(self, in_ch, out_ch):
214 | super(APN_Module, self).__init__()
215 | # global pooling branch
216 | self.branch1 = nn.Sequential(
217 | nn.AdaptiveAvgPool2d(1),
218 | Conv2dBnRelu(in_ch, out_ch, kernel_size=1, stride=1, padding=0)
219 | )
220 | # middle branch
221 | self.mid = nn.Sequential(
222 | Conv2dBnRelu(in_ch, out_ch, kernel_size=1, stride=1, padding=0)
223 | )
224 | self.down1 = Conv2dBnRelu(in_ch, 1, kernel_size=7, stride=2, padding=3)
225 |
226 | self.down2 = Conv2dBnRelu(1, 1, kernel_size=5, stride=2, padding=2)
227 |
228 | self.down3 = nn.Sequential(
229 | Conv2dBnRelu(1, 1, kernel_size=3, stride=2, padding=1),
230 | Conv2dBnRelu(1, 1, kernel_size=3, stride=1, padding=1)
231 | )
232 |
233 | self.conv2 = Conv2dBnRelu(1, 1, kernel_size=5, stride=1, padding=2)
234 | self.conv1 = Conv2dBnRelu(1, 1, kernel_size=7, stride=1, padding=3)
235 |
236 | def forward(self, x):
237 |
238 | h = x.size()[2]
239 | w = x.size()[3]
240 |
241 | b1 = self.branch1(x)
242 | # b1 = Interpolate(size=(h, w), mode="bilinear")(b1)
243 | b1= interpolate(b1, size=(h, w), mode="bilinear", align_corners=True)
244 |
245 | mid = self.mid(x)
246 |
247 | x1 = self.down1(x)
248 | x2 = self.down2(x1)
249 | x3 = self.down3(x2)
250 | # x3 = Interpolate(size=(h // 4, w // 4), mode="bilinear")(x3)
251 | x3= interpolate(x3, size=(h // 4, w // 4), mode="bilinear", align_corners=True)
252 | x2 = self.conv2(x2)
253 | x = x2 + x3
254 | # x = Interpolate(size=(h // 2, w // 2), mode="bilinear")(x)
255 | x= interpolate(x, size=(h // 2, w // 2), mode="bilinear", align_corners=True)
256 |
257 | x1 = self.conv1(x1)
258 | x = x + x1
259 | # x = Interpolate(size=(h, w), mode="bilinear")(x)
260 | x= interpolate(x, size=(h, w), mode="bilinear", align_corners=True)
261 |
262 | x = torch.mul(x, mid)
263 |
264 | x = x + b1
265 |
266 |
267 | return x
268 |
269 |
270 |
271 |
272 | class Decoder (nn.Module):
273 | def __init__(self, num_classes):
274 | super().__init__()
275 |
276 | self.apn = APN_Module(in_ch=128,out_ch=20)
277 | # self.upsample = Interpolate(size=(512, 1024), mode="bilinear")
278 | # self.output_conv = nn.ConvTranspose2d(16, num_classes, kernel_size=4, stride=2, padding=1, output_padding=0, bias=True)
279 | # self.output_conv = nn.ConvTranspose2d(16, num_classes, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True)
280 | # self.output_conv = nn.ConvTranspose2d(16, num_classes, kernel_size=2, stride=2, padding=0, output_padding=0, bias=True)
281 |
282 | def forward(self, input):
283 |
284 | output = self.apn(input)
285 | out = interpolate(output, size=(512, 1024), mode="bilinear", align_corners=True)
286 | # out = self.upsample(output)
287 | # print(out.shape)
288 | return out
289 |
290 |
291 | # LEDNet
292 | class Net(nn.Module):
293 | def __init__(self, num_classes, encoder=None):
294 | super().__init__()
295 |
296 | if (encoder == None):
297 | self.encoder = Encoder(num_classes)
298 | else:
299 | self.encoder = encoder
300 | self.decoder = Decoder(num_classes)
301 |
302 | def forward(self, input, only_encode=False):
303 | if only_encode:
304 | return self.encoder.forward(input, predict=True)
305 | else:
306 | output = self.encoder(input)
307 | return self.decoder.forward(output)
308 |
--------------------------------------------------------------------------------
/train/lednet_1.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.init as init
4 | import torch.nn.functional as F
5 | from torch.nn.functional import interpolate as interpolate
6 |
7 | def split(x):
8 | c = int(x.size()[1])
9 | c1 = round(c * 0.5)
10 | x1 = x[:, :c1, :, :].contiguous()
11 | x2 = x[:, c1:, :, :].contiguous()
12 |
13 | return x1, x2
14 |
15 | def channel_shuffle(x,groups):
16 | batchsize, num_channels, height, width = x.data.size()
17 |
18 | channels_per_group = num_channels // groups
19 |
20 | #reshape
21 | x = x.view(batchsize,groups,
22 | channels_per_group,height,width)
23 |
24 | x = torch.transpose(x,1,2).contiguous()
25 |
26 | #flatten
27 | x = x.view(batchsize,-1,height,width)
28 |
29 | return x
30 |
31 |
32 | class Conv2dBnRelu(nn.Module):
33 | def __init__(self,in_ch,out_ch,kernel_size=3,stride=1,padding=0,dilation=1,bias=True):
34 | super(Conv2dBnRelu,self).__init__()
35 |
36 | self.conv = nn.Sequential(
37 | nn.Conv2d(in_ch,out_ch,kernel_size,stride,padding,dilation=dilation,bias=bias),
38 | nn.BatchNorm2d(out_ch, eps=1e-3),
39 | nn.ReLU(inplace=True)
40 | )
41 |
42 | def forward(self, x):
43 | return self.conv(x)
44 |
45 |
46 | ## After the Concat -> BN, you can also add Dropout (as in SS_nbt_module); it may give better results.
47 | class DownsamplerBlock (nn.Module):
48 | def __init__(self, in_channel, out_channel):
49 | super(DownsamplerBlock,self).__init__()
50 |
51 | self.conv = nn.Conv2d(in_channel, out_channel-in_channel, (3, 3), stride=2, padding=1, bias=True)
52 | self.pool = nn.MaxPool2d(2, stride=2)
53 | self.bn = nn.BatchNorm2d(out_channel, eps=1e-3)
54 | self.relu = nn.ReLU(inplace=True)
55 |
56 | def forward(self, input):
57 | x1 = self.pool(input)
58 | x2 = self.conv(input)
59 |
60 | diffY = x2.size()[2] - x1.size()[2]
61 | diffX = x2.size()[3] - x1.size()[3]
62 |
63 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
64 | diffY // 2, diffY - diffY // 2])
65 |
66 | output = torch.cat([x2, x1], 1)
67 | output = self.bn(output)
68 | output = self.relu(output)
69 | return output
70 |
71 |
72 | class SS_nbt_module(nn.Module):
73 | def __init__(self, chann, dropprob, dilated):
74 | super(SS_nbt_module,self).__init__()
75 |
76 | oup_inc = chann//2
77 |
78 | #dw
79 | self.conv3x1_1_l = nn.Conv2d(oup_inc, oup_inc, (3,1), stride=1, padding=(1,0), bias=True)
80 |
81 | self.conv1x3_1_l = nn.Conv2d(oup_inc, oup_inc, (1,3), stride=1, padding=(0,1), bias=True)
82 |
83 | self.bn1_l = nn.BatchNorm2d(oup_inc, eps=1e-03)
84 |
85 | self.conv3x1_2_l = nn.Conv2d(oup_inc, oup_inc, (3,1), stride=1, padding=(1*dilated,0), bias=True, dilation = (dilated,1))
86 |
87 | self.conv1x3_2_l = nn.Conv2d(oup_inc, oup_inc, (1,3), stride=1, padding=(0,1*dilated), bias=True, dilation = (1,dilated))
88 |
89 | self.bn2_l = nn.BatchNorm2d(oup_inc, eps=1e-03)
90 |
91 | #dw
92 | self.conv3x1_1_r = nn.Conv2d(oup_inc, oup_inc, (3,1), stride=1, padding=(1,0), bias=True)
93 |
94 | self.conv1x3_1_r = nn.Conv2d(oup_inc, oup_inc, (1,3), stride=1, padding=(0,1), bias=True)
95 |
96 | self.bn1_r = nn.BatchNorm2d(oup_inc, eps=1e-03)
97 |
98 | self.conv3x1_2_r = nn.Conv2d(oup_inc, oup_inc, (3,1), stride=1, padding=(1*dilated,0), bias=True, dilation = (dilated,1))
99 |
100 | self.conv1x3_2_r = nn.Conv2d(oup_inc, oup_inc, (1,3), stride=1, padding=(0,1*dilated), bias=True, dilation = (1,dilated))
101 |
102 | self.bn2_r = nn.BatchNorm2d(oup_inc, eps=1e-03)
103 |
104 | self.relu = nn.ReLU(inplace=True)
105 | self.dropout = nn.Dropout2d(dropprob)
106 |
107 | @staticmethod
108 | def _concat(x,out):
109 | return torch.cat((x,out),1)
110 |
111 | def forward(self, input):
112 |
113 | # x1 = input[:,:(input.shape[1]//2),:,:]
114 | # x2 = input[:,(input.shape[1]//2):,:,:]
115 |
116 | residual = input
117 | x1, x2 = split(input)
118 |
119 | output1 = self.conv3x1_1_l(x1)
120 | output1 = self.relu(output1)
121 | output1 = self.conv1x3_1_l(output1)
122 | output1 = self.bn1_l(output1)
123 | output1 = self.relu(output1)
124 |
125 | output1 = self.conv3x1_2_l(output1)
126 | output1 = self.relu(output1)
127 | output1 = self.conv1x3_2_l(output1)
128 | output1 = self.bn2_l(output1)
129 |
130 | output2 = self.conv1x3_1_r(x2)
131 | output2 = self.relu(output2)
132 | output2 = self.conv3x1_1_r(output2)
133 | output2 = self.bn1_r(output2)
134 | output2 = self.relu(output2)
135 |
136 | output2 = self.conv1x3_2_r(output2)
137 | output2 = self.relu(output2)
138 | output2 = self.conv3x1_2_r(output2)
139 | output2 = self.bn2_r(output2)
140 |
141 | if (self.dropout.p != 0):
142 | output1 = self.dropout(output1)
143 | output2 = self.dropout(output2)
144 |
145 | out = self._concat(output1,output2)
146 | out = F.relu(residual + out, inplace=True)
147 |
148 | return channel_shuffle(out, 2)
149 |
150 |
151 | class Encoder(nn.Module):
152 | def __init__(self, num_classes):
153 | super().__init__()
154 | self.initial_block = DownsamplerBlock(3,32)
155 |
156 |
157 | self.layers = nn.ModuleList()
158 |
159 | for x in range(0, 3):
160 | self.layers.append(SS_nbt_module(32, 0.03, 1))
161 |
162 |
163 | self.layers.append(DownsamplerBlock(32,64))
164 |
165 |
166 | for x in range(0, 2):
167 | self.layers.append(SS_nbt_module(64, 0.03, 1))
168 |
169 | self.layers.append(DownsamplerBlock(64,128))
170 |
171 | for x in range(0, 1):
172 | self.layers.append(SS_nbt_module(128, 0.3, 1))
173 | self.layers.append(SS_nbt_module(128, 0.3, 2))
174 | self.layers.append(SS_nbt_module(128, 0.3, 5))
175 | self.layers.append(SS_nbt_module(128, 0.3, 9))
176 |
177 | for x in range(0, 1):
178 | self.layers.append(SS_nbt_module(128, 0.3, 2))
179 | self.layers.append(SS_nbt_module(128, 0.3, 5))
180 | self.layers.append(SS_nbt_module(128, 0.3, 9))
181 | self.layers.append(SS_nbt_module(128, 0.3, 17))
182 |
183 |
184 | #Only in encoder mode:
185 | self.output_conv = nn.Conv2d(128, num_classes, 1, stride=1, padding=0, bias=True)
186 |
187 | def forward(self, input, predict=False):
188 |
189 | output = self.initial_block(input)
190 |
191 | for layer in self.layers:
192 | output = layer(output)
193 |
194 | if predict:
195 | output = self.output_conv(output)
196 |
197 | return output
198 |
199 | class Interpolate(nn.Module):
200 | def __init__(self,size,mode):
201 | super(Interpolate,self).__init__()
202 |
203 | self.interp = nn.functional.interpolate
204 | self.size = size
205 | self.mode = mode
206 | def forward(self,x):
207 | x = self.interp(x,size=self.size,mode=self.mode,align_corners=True)
208 | return x
209 |
210 |
211 | class APN_Module(nn.Module):
212 | def __init__(self, in_ch, out_ch):
213 | super(APN_Module, self).__init__()
214 | # global pooling branch
215 | self.branch1 = nn.Sequential(
216 | nn.AdaptiveAvgPool2d(1),
217 | Conv2dBnRelu(in_ch, out_ch, kernel_size=1, stride=1, padding=0)
218 | )
219 |
220 | # middle branch
221 | self.mid = nn.Sequential(
222 | Conv2dBnRelu(in_ch, out_ch, kernel_size=1, stride=1, padding=0)
223 | )
224 |
225 | self.down1 = Conv2dBnRelu(in_ch, 128, kernel_size=7, stride=2, padding=3)
226 |
227 | self.down2 = Conv2dBnRelu(128, 128, kernel_size=5, stride=2, padding=2)
228 |
229 | self.down3 = nn.Sequential(
230 | Conv2dBnRelu(128, 128, kernel_size=3, stride=2, padding=1),
231 | Conv2dBnRelu(128, 20, kernel_size=1, stride=1, padding=0)
232 | )
233 |
234 | self.conv2 = Conv2dBnRelu(128, 20, kernel_size=1, stride=1, padding=0)
235 | self.conv1 = Conv2dBnRelu(128, 20, kernel_size=1, stride=1, padding=0)
236 |
237 | def forward(self, x):
238 |
239 | h = x.size()[2]
240 | w = x.size()[3]
241 |
242 | b1 = self.branch1(x)
243 | #b1 = Interpolate(size=(h, w), mode="bilinear")(b1)
244 | b1= interpolate(b1, size=(h, w), mode="bilinear", align_corners=True)
245 |
246 | mid = self.mid(x)
247 |
248 | x1 = self.down1(x)
249 | x2 = self.down2(x1)
250 | x3 = self.down3(x2)
251 | #x3 = Interpolate(size=(h // 4, w // 4), mode="bilinear")(x3)
252 | x3= interpolate(x3, size=(h // 4, w // 4), mode="bilinear", align_corners=True)
253 | x2 = self.conv2(x2)
254 | x = x2 + x3
255 | #x = Interpolate(size=(h // 2, w // 2), mode="bilinear")(x)
256 | x= interpolate(x, size=(h // 2, w // 2), mode="bilinear", align_corners=True)
257 |
258 | x1 = self.conv1(x1)
259 | x = x + x1
260 | #x = Interpolate(size=(h, w), mode="bilinear")(x)
261 | x= interpolate(x, size=(h, w), mode="bilinear", align_corners=True)
262 |
263 | x = torch.mul(x, mid)
264 |
265 | x = x + b1
266 |
267 |
268 | return x
269 |
270 |
271 |
272 |
273 | class Decoder (nn.Module):
274 | def __init__(self, num_classes):
275 | super().__init__()
276 |
277 | self.apn = APN_Module(in_ch=128,out_ch=20)
278 | #self.upsample = Interpolate(size=(512, 1024), mode="bilinear")
279 | #self.output_conv = nn.ConvTranspose2d(16, num_classes, kernel_size=4, stride=2, padding=1, output_padding=0, bias=True)
280 | #self.output_conv = nn.ConvTranspose2d(16, num_classes, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True)
281 | #self.output_conv = nn.ConvTranspose2d(16, num_classes, kernel_size=2, stride=2, padding=0, output_padding=0, bias=True)
282 |
283 | def forward(self, input):
284 |
285 | output = self.apn(input)
286 | out = interpolate(output, size=(512, 1024), mode="bilinear", align_corners=True)
287 | #out = self.upsample(output)
288 | return out
289 |
290 |
291 | # LEDNet
292 | class Net(nn.Module):
293 | def __init__(self, num_classes, encoder=None):
294 | super().__init__()
295 |
296 | if (encoder == None):
297 | self.encoder = Encoder(num_classes)
298 | else:
299 | self.encoder = encoder
300 | self.decoder = Decoder(num_classes)
301 |
302 | def forward(self, input, only_encode=False):
303 | if only_encode:
304 | return self.encoder.forward(input, predict=True)
305 | else:
306 | output = self.encoder(input)
307 | return self.decoder.forward(output)
308 |
--------------------------------------------------------------------------------
/train/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import time
4 | import numpy as np
5 | import torch
6 | import math
7 | import sys
8 |
9 | cur_path = os.path.abspath(os.path.dirname(__file__))
10 | root_path = os.path.split(cur_path)[0]
11 | sys.path.append(root_path)
12 |
13 |
14 |
15 | from PIL import Image, ImageOps
16 | from argparse import ArgumentParser
17 |
18 | from torch.optim import SGD, Adam, lr_scheduler
19 | from torch.autograd import Variable
20 | from torch.utils.data import DataLoader
21 | from torchvision.transforms import Compose, CenterCrop, Normalize, Resize, Pad
22 | from torchvision.transforms import ToTensor, ToPILImage
23 |
24 | from utils.dataset import VOC12,cityscapes
25 | from utils.transform import Relabel, ToLabel, Colorize
26 | from utils.visualize import Dashboard
27 | from utils.loss import CrossEntropyLoss2d
28 |
29 | import importlib
30 | from utils.iouEval import iouEval, getColorEntry
31 |
32 | from shutil import copyfile
33 |
34 | NUM_CHANNELS = 3
35 | NUM_CLASSES = 20 #pascal=22, cityscapes=20
36 |
37 | color_transform = Colorize(NUM_CLASSES)
38 | image_transform = ToPILImage()
39 |
40 | #Augmentations - different function implemented to perform random augments on both image and target
41 | class MyCoTransform(object):
42 | def __init__(self, enc, augment=True, height=512):
43 | self.enc=enc
44 | self.augment = augment
45 | self.height = height
46 | pass
47 | def __call__(self, input, target):
48 | # do something to both images
49 | input = Resize(self.height, Image.BILINEAR)(input)
50 | target = Resize(self.height, Image.NEAREST)(target)
51 |
52 | if(self.augment):
53 | # Random hflip
54 | hflip = random.random()
55 | if (hflip < 0.5):
56 | input = input.transpose(Image.FLIP_LEFT_RIGHT)
57 | target = target.transpose(Image.FLIP_LEFT_RIGHT)
58 |
59 | #Random translation 0-2 pixels (fill rest with padding)
60 | transX = random.randint(-2, 2)
61 | transY = random.randint(-2, 2)
62 |
63 | input = ImageOps.expand(input, border=(transX,transY,0,0), fill=0)
64 | target = ImageOps.expand(target, border=(transX,transY,0,0), fill=255) #pad label filling with 255
65 | input = input.crop((0, 0, input.size[0]-transX, input.size[1]-transY))
66 | target = target.crop((0, 0, target.size[0]-transX, target.size[1]-transY))
67 |
68 | input = ToTensor()(input)
69 | if (self.enc):
70 | target = Resize(int(self.height/8), Image.NEAREST)(target)
71 | target = ToLabel()(target)
72 | target = Relabel(255, 19)(target)
73 |
74 | return input, target
75 |
76 |
77 |
78 |
79 |
80 | def train(args, model, enc=False):
81 | best_acc = 0
82 |
83 | #TODO: calculate weights by processing the dataset histogram (currently they are set by hand from the torch values)
84 | #create a loader to run over all images and calculate the histogram of labels, then create the weight array using class balancing
85 |
86 | weight = torch.ones(NUM_CLASSES)
87 | if (enc):
88 | weight[0] = 2.3653597831726
89 | weight[1] = 4.4237880706787
90 | weight[2] = 2.9691488742828
91 | weight[3] = 5.3442072868347
92 | weight[4] = 5.2983593940735
93 | weight[5] = 5.2275490760803
94 | weight[6] = 5.4394111633301
95 | weight[7] = 5.3659925460815
96 | weight[8] = 3.4170460700989
97 | weight[9] = 5.2414722442627
98 | weight[10] = 4.7376127243042
99 | weight[11] = 5.2286224365234
100 | weight[12] = 5.455126285553
101 | weight[13] = 4.3019247055054
102 | weight[14] = 5.4264230728149
103 | weight[15] = 5.4331531524658
104 | weight[16] = 5.433765411377
105 | weight[17] = 5.4631009101868
106 | weight[18] = 5.3947434425354
107 | else:
108 | weight[0] = 2.8149201869965
109 | weight[1] = 6.9850029945374
110 | weight[2] = 3.7890393733978
111 | weight[3] = 9.9428062438965
112 | weight[4] = 9.7702074050903
113 | weight[5] = 9.5110931396484
114 | weight[6] = 10.311357498169
115 | weight[7] = 10.026463508606
116 | weight[8] = 4.6323022842407
117 | weight[9] = 9.5608062744141
118 | weight[10] = 7.8698215484619
119 | weight[11] = 9.5168733596802
120 | weight[12] = 10.373730659485
121 | weight[13] = 6.6616044044495
122 | weight[14] = 10.260489463806
123 | weight[15] = 10.287888526917
124 | weight[16] = 10.289801597595
125 | weight[17] = 10.405355453491
126 | weight[18] = 10.138095855713
127 |
128 | weight[19] = 0
129 |
130 | assert os.path.exists(args.datadir), "Error: datadir (dataset directory) could not be loaded"
131 |
132 | co_transform = MyCoTransform(enc, augment=True, height=args.height)#512)
133 | co_transform_val = MyCoTransform(enc, augment=False, height=args.height)#512)
134 | dataset_train = cityscapes(args.datadir, co_transform, 'train')
135 | dataset_val = cityscapes(args.datadir, co_transform_val, 'val')
136 |
137 | loader = DataLoader(dataset_train, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=True)
138 | loader_val = DataLoader(dataset_val, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False)
139 |
140 | if args.cuda:
141 | weight = weight.cuda()
142 | criterion = CrossEntropyLoss2d(weight)
143 | print(type(criterion))
144 |
145 | savedir = f'../save/{args.savedir}'
146 |
147 | if (enc):
148 | automated_log_path = savedir + "/automated_log_encoder.txt"
149 | modeltxtpath = savedir + "/model_encoder.txt"
150 | else:
151 | automated_log_path = savedir + "/automated_log.txt"
152 | modeltxtpath = savedir + "/model.txt"
153 |
154 | if (not os.path.exists(automated_log_path)): #dont add first line if it exists
155 | with open(automated_log_path, "a") as myfile:
156 | myfile.write("Epoch\t\tTrain-loss\t\tTest-loss\t\tTrain-IoU\t\tTest-IoU\t\tlearningRate")
157 |
158 | with open(modeltxtpath, "w") as myfile:
159 | myfile.write(str(model))
160 |
161 |
162 | #TODO: reduce memory in first gpu: https://discuss.pytorch.org/t/multi-gpu-training-memory-usage-in-balance/4163/4
163 | #https://github.com/pytorch/pytorch/issues/1893
164 |
165 | #optimizer = Adam(model.parameters(), 5e-4, (0.9, 0.999), eps=1e-08, weight_decay=2e-4) ## scheduler 1
166 | optimizer = Adam(model.parameters(), 5e-4, (0.9, 0.999), eps=1e-08, weight_decay=1e-4) ## scheduler 2
167 |
168 | start_epoch = 1
169 | if args.resume:
170 | #Must load weights, optimizer, epoch and best value.
171 | if enc:
172 | filenameCheckpoint = savedir + '/checkpoint_enc.pth.tar'
173 | else:
174 | filenameCheckpoint = savedir + '/checkpoint.pth.tar'
175 |
176 | assert os.path.exists(filenameCheckpoint), "Error: resume option was used but checkpoint was not found in folder"
177 | checkpoint = torch.load(filenameCheckpoint)
178 | start_epoch = checkpoint['epoch']
179 | model.load_state_dict(checkpoint['state_dict'])
180 | optimizer.load_state_dict(checkpoint['optimizer'])
181 | best_acc = checkpoint['best_acc']
182 | print("=> Loaded checkpoint at epoch {}".format(checkpoint['epoch']))
183 |
184 | #scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5) # set up scheduler ## scheduler 1
185 | lambda1 = lambda epoch: pow((1-((epoch-1)/args.num_epochs)),0.9) ## scheduler 2
186 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1) ## scheduler 2
187 |
188 | if args.visualize and args.steps_plot > 0:
189 | board = Dashboard(args.port)
190 |
191 | for epoch in range(start_epoch, args.num_epochs+1):
192 | print("----- TRAINING - EPOCH", epoch, "-----")
193 |
194 | scheduler.step(epoch) ## scheduler 2
195 |
196 | epoch_loss = []
197 | time_train = []
198 |
199 | doIouTrain = args.iouTrain
200 | doIouVal = args.iouVal
201 |
202 | if (doIouTrain):
203 | iouEvalTrain = iouEval(NUM_CLASSES)
204 |
205 | usedLr = 0
206 | for param_group in optimizer.param_groups:
207 | print("LEARNING RATE: ", param_group['lr'])
208 | usedLr = float(param_group['lr'])
209 |
210 | model.train()
211 | for step, (images, labels) in enumerate(loader):
212 |
213 | start_time = time.time()
214 |
215 | imgs_batch = images.shape[0]
216 | if imgs_batch != args.batch_size:
217 | break
218 |
219 | if args.cuda:
220 | inputs = images.cuda()
221 | targets = labels.cuda()
222 |
223 | outputs = model(inputs, only_encode=enc)
224 |
225 | #print("targets", np.unique(targets[:, 0].cpu().data.numpy()))
226 |
227 | optimizer.zero_grad()
228 | loss = criterion(outputs, targets[:, 0])
229 |
230 | loss.backward()
231 | optimizer.step()
232 |
233 | epoch_loss.append(loss.item())
234 | time_train.append(time.time() - start_time)
235 |
236 | if (doIouTrain):
237 | #start_time_iou = time.time()
238 | iouEvalTrain.addBatch(outputs.max(1)[1].unsqueeze(1).data, targets.data)
239 | #print ("Time to add confusion matrix: ", time.time() - start_time_iou)
240 |
241 | #print(outputs.size())
242 | if args.visualize and args.steps_plot > 0 and step % args.steps_plot == 0:
243 | start_time_plot = time.time()
244 | image = inputs[0].cpu().data
245 | #image[0] = image[0] * .229 + .485
246 | #image[1] = image[1] * .224 + .456
247 | #image[2] = image[2] * .225 + .406
248 | #print("output", np.unique(outputs[0].cpu().max(0)[1].data.numpy()))
249 | board.image(image, f'input (epoch: {epoch}, step: {step})')
250 | if isinstance(outputs, list): #merge gpu tensors
251 | board.image(color_transform(outputs[0][0].cpu().max(0)[1].data.unsqueeze(0)),
252 | f'output (epoch: {epoch}, step: {step})')
253 | else:
254 | board.image(color_transform(outputs[0].cpu().max(0)[1].data.unsqueeze(0)),
255 | f'output (epoch: {epoch}, step: {step})')
256 | board.image(color_transform(targets[0].cpu().data),
257 | f'target (epoch: {epoch}, step: {step})')
258 | print ("Time to paint images: ", time.time() - start_time_plot)
259 | if args.steps_loss > 0 and step % args.steps_loss == 0:
260 | average = sum(epoch_loss) / len(epoch_loss)
261 | print(f'loss: {average:0.4} (epoch: {epoch}, step: {step})',
262 | "// Avg time/img: %.4f s" % (sum(time_train) / len(time_train) / args.batch_size))
263 |
264 |
265 | average_epoch_loss_train = sum(epoch_loss) / len(epoch_loss)
266 |
267 | iouTrain = 0
268 | if (doIouTrain):
269 | iouTrain, iou_classes = iouEvalTrain.getIoU()
270 | iouStr = getColorEntry(iouTrain)+'{:0.2f}'.format(iouTrain*100) + '\033[0m'
271 | print ("EPOCH IoU on TRAIN set: ", iouStr, "%")
272 |
273 | #Validate on 500 val images after each epoch of training
274 | print("----- VALIDATING - EPOCH", epoch, "-----")
275 | model.eval()
276 | epoch_loss_val = []
277 | time_val = []
278 |
279 | if (doIouVal):
280 | iouEvalVal = iouEval(NUM_CLASSES)
281 |
282 | for step, (images, labels) in enumerate(loader_val):
283 | start_time = time.time()
284 |
285 | imgs_batch = images.shape[0]
286 | if imgs_batch != args.batch_size:
287 | break
288 |
289 | if args.cuda:
290 | images = images.cuda()
291 | labels = labels.cuda()
292 |
293 | with torch.no_grad():
294 | inputs = Variable(images)
295 | targets = Variable(labels)
296 |
297 | outputs = model(inputs, only_encode=enc)
298 |
299 | loss = criterion(outputs, targets[:, 0])
300 | epoch_loss_val.append(loss.item())
301 | time_val.append(time.time() - start_time)
302 |
303 |
304 | #Add batch to calculate TP, FP and FN for iou estimation
305 | if (doIouVal):
306 | #start_time_iou = time.time()
307 | iouEvalVal.addBatch(outputs.max(1)[1].unsqueeze(1).data, targets.data)
308 | #print ("Time to add confusion matrix: ", time.time() - start_time_iou)
309 |
310 | if args.visualize and args.steps_plot > 0 and step % args.steps_plot == 0:
311 | start_time_plot = time.time()
312 | image = inputs[0].cpu().data
313 | board.image(image, f'VAL input (epoch: {epoch}, step: {step})')
314 | if isinstance(outputs, list): #merge gpu tensors
315 | board.image(color_transform(outputs[0][0].cpu().max(0)[1].data.unsqueeze(0)),
316 | f'VAL output (epoch: {epoch}, step: {step})')
317 | else:
318 | board.image(color_transform(outputs[0].cpu().max(0)[1].data.unsqueeze(0)),
319 | f'VAL output (epoch: {epoch}, step: {step})')
320 | board.image(color_transform(targets[0].cpu().data),
321 | f'VAL target (epoch: {epoch}, step: {step})')
322 | print ("Time to paint images: ", time.time() - start_time_plot)
323 | if args.steps_loss > 0 and step % args.steps_loss == 0:
324 | average = sum(epoch_loss_val) / len(epoch_loss_val)
325 | print(f'VAL loss: {average:0.4} (epoch: {epoch}, step: {step})',
326 | "// Avg time/img: %.4f s" % (sum(time_val) / len(time_val) / args.batch_size))
327 |
328 |
329 | average_epoch_loss_val = sum(epoch_loss_val) / len(epoch_loss_val)
330 | #scheduler.step(average_epoch_loss_val, epoch) ## scheduler 1 # update lr if needed
331 |
332 | iouVal = 0
333 | if (doIouVal):
334 | iouVal, iou_classes = iouEvalVal.getIoU()
335 | iouStr = getColorEntry(iouVal)+'{:0.2f}'.format(iouVal*100) + '\033[0m'
336 | print ("EPOCH IoU on VAL set: ", iouStr, "%")
337 |
338 |
339 | # remember best valIoU and save checkpoint
340 | if iouVal == 0:
341 | current_acc = -average_epoch_loss_val
342 | else:
343 | current_acc = iouVal
344 | is_best = current_acc > best_acc
345 | best_acc = max(current_acc, best_acc)
346 | if enc:
347 | filenameCheckpoint = savedir + '/checkpoint_enc.pth.tar'
348 | filenameBest = savedir + '/model_best_enc.pth.tar'
349 | else:
350 | filenameCheckpoint = savedir + '/checkpoint.pth.tar'
351 | filenameBest = savedir + '/model_best.pth.tar'
352 | save_checkpoint({
353 | 'epoch': epoch + 1,
354 | 'arch': str(model),
355 | 'state_dict': model.state_dict(),
356 | 'best_acc': best_acc,
357 | 'optimizer' : optimizer.state_dict(),
358 | }, is_best, filenameCheckpoint, filenameBest)
359 |
360 | #SAVE MODEL AFTER EPOCH
361 | if (enc):
362 | filename = f'{savedir}/model_encoder-{epoch:03}.pth'
363 | filenamebest = f'{savedir}/model_encoder_best.pth'
364 | else:
365 | filename = f'{savedir}/model-{epoch:03}.pth'
366 | filenamebest = f'{savedir}/model_best.pth'
367 | if args.epochs_save > 0 and epoch > 0 and epoch % args.epochs_save == 0:  # --epochs-save counts epochs, not steps
368 | torch.save(model.state_dict(), filename)
369 | print(f'save: {filename} (epoch: {epoch})')
370 | if (is_best):
371 | torch.save(model.state_dict(), filenamebest)
372 | print(f'save: {filenamebest} (epoch: {epoch})')
373 | if (not enc):
374 | with open(savedir + "/best.txt", "w") as myfile:
375 | myfile.write("Best epoch is %d, with Val-IoU= %.4f" % (epoch, iouVal))
376 | else:
377 | with open(savedir + "/best_encoder.txt", "w") as myfile:
378 | myfile.write("Best epoch is %d, with Val-IoU= %.4f" % (epoch, iouVal))
379 |
380 | #SAVE TO FILE A ROW WITH THE EPOCH RESULT (train loss, val loss, train IoU, val IoU)
381 | #Epoch Train-loss Test-loss Train-IoU Test-IoU learningRate
382 | with open(automated_log_path, "a") as myfile:
383 | myfile.write("\n%d\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.8f" % (epoch, average_epoch_loss_train, average_epoch_loss_val, iouTrain, iouVal, usedLr ))
384 |
385 | return(model) #return model (convenience for encoder-decoder training)
386 |
387 | def save_checkpoint(state, is_best, filenameCheckpoint, filenameBest):
388 | torch.save(state, filenameCheckpoint)
389 | if is_best:
390 | print ("Saving model as best")
391 | torch.save(state, filenameBest)
392 |
393 |
394 | def main(args):
395 | savedir = f'../save/{args.savedir}'
396 |
397 | if not os.path.exists(savedir):
398 | os.makedirs(savedir)
399 |
400 | with open(savedir + '/opts.txt', "w") as myfile:
401 | myfile.write(str(args))
402 |
403 | #Load Model
404 | assert os.path.exists(args.model + ".py"), "Error: model definition not found"
405 | model_file = importlib.import_module(args.model)
406 | model = model_file.Net(NUM_CLASSES)
407 | copyfile(args.model + ".py", savedir + '/' + args.model + ".py")
408 |
409 | if args.cuda:
410 | model = torch.nn.DataParallel(model).cuda()
411 |
412 | if args.state:
413 | #if args.state is provided then load this state for training
414 | #Note: this only loads initialized weights. If you want to resume a training use "--resume" option!!
415 | """
416 | try:
417 | model.load_state_dict(torch.load(args.state))
418 | except AssertionError:
419 | model.load_state_dict(torch.load(args.state,
420 | map_location=lambda storage, loc: storage))
421 | #When model is saved as DataParallel it adds a model. to each key. To remove:
422 | #state_dict = {k.partition('model.')[2]: v for k,v in state_dict}
423 | #https://discuss.pytorch.org/t/prefix-parameter-names-in-saved-model-if-trained-by-multi-gpu/494
424 | """
425 | def load_my_state_dict(model, state_dict): #custom function to load model when not all dict keys are there
426 | own_state = model.state_dict()
427 | for name, param in state_dict.items():
428 | if name not in own_state:
429 | continue
430 | own_state[name].copy_(param)
431 | return model
432 |
433 | #print(torch.load(args.state))
434 | model = load_my_state_dict(model, torch.load(args.state))
435 |
436 | """
437 | def weights_init(m):
438 | classname = m.__class__.__name__
439 | if classname.find('Conv') != -1:
440 | #m.weight.data.normal_(0.0, 0.02)
441 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
442 | m.weight.data.normal_(0, math.sqrt(2. / n))
443 | elif classname.find('BatchNorm') != -1:
444 | #m.weight.data.normal_(1.0, 0.02)
445 | m.weight.data.fill_(1)
446 | m.bias.data.fill_(0)
447 |
448 | #TO ACCESS MODEL IN DataParallel: next(model.children())
449 | #next(model.children()).decoder.apply(weights_init)
450 | #Reinitialize weights for decoder
451 |
452 | next(model.children()).decoder.layers.apply(weights_init)
453 | next(model.children()).decoder.output_conv.apply(weights_init)
454 |
455 | #print(model.state_dict())
456 | f = open('weights5.txt', 'w')
457 | f.write(str(model.state_dict()))
458 | f.close()
459 | """
460 |
461 | #train(args, model)
462 | if (not args.decoder):
463 | print("========== ENCODER TRAINING ===========")
464 | model = train(args, model, True) #Train encoder
465 | #CAREFUL: for some reason, after training encoder alone, the decoder gets weights=0.
466 | #We must reinit decoder weights or reload network passing only encoder in order to train decoder
467 | print("========== DECODER TRAINING ===========")
468 | if (not args.state):
469 | if args.pretrainedEncoder:
470 | print("Loading encoder pretrained in imagenet")
471 | from lednet_imagenet import LEDNet as LEDNet_imagenet
472 | pretrainedEnc = torch.nn.DataParallel(LEDNet_imagenet(1000))
473 | pretrainedEnc.load_state_dict(torch.load(args.pretrainedEncoder)['state_dict'])
474 | pretrainedEnc = next(pretrainedEnc.children()).features.encoder
475 | if (not args.cuda):
476 |                     pretrainedEnc = pretrainedEnc.cpu() #the loaded encoder was probably saved on CUDA
477 | else:
478 | pretrainedEnc = next(model.children()).encoder
479 | model = model_file.Net(NUM_CLASSES, encoder=pretrainedEnc) #Add decoder to encoder
480 | if args.cuda:
481 | model = torch.nn.DataParallel(model).cuda()
482 |             #When loading the encoder, reinitialize the decoder weights because they are set to 0 during encoder training
483 | model = train(args, model, False) #Train decoder
484 | print("========== TRAINING FINISHED ===========")
485 |
486 | if __name__ == '__main__':
487 | parser = ArgumentParser()
488 | parser.add_argument('--cuda', action='store_true', default=True) #NOTE: cpu-only has not been tested so you might have to change code if you deactivate this flag
489 | parser.add_argument('--model', default= "lednet")
490 | parser.add_argument('--state')
491 |
492 | parser.add_argument('--port', type=int, default=8097)
493 | parser.add_argument('--datadir', default=os.getenv("HOME") + "/datasets/cityscapes/")
494 | parser.add_argument('--height', type=int, default=512)
495 | parser.add_argument('--num-epochs', type=int, default=300)
496 | parser.add_argument('--num-workers', type=int, default=4)
497 | parser.add_argument('--batch-size', type=int, default=5)
498 | parser.add_argument('--steps-loss', type=int, default=50)
499 | parser.add_argument('--steps-plot', type=int, default=50)
500 | parser.add_argument('--epochs-save', type=int, default=0) #You can use this value to save model every X epochs
501 | parser.add_argument('--savedir', required=True)
502 | parser.add_argument('--decoder', action='store_true')
503 | parser.add_argument('--pretrainedEncoder') #, default=" ")
504 | parser.add_argument('--visualize', action='store_true')
505 |
506 | parser.add_argument('--iouTrain', action='store_true', default=True) #recommended: False (takes more time to train otherwise)
507 | parser.add_argument('--iouVal', action='store_true', default=True)
508 | parser.add_argument('--resume', action='store_true') #Use this flag to load last checkpoint for training
509 |
510 | main(parser.parse_args())
511 |
--------------------------------------------------------------------------------
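The checkpoint dict saved by `save_checkpoint` in train/main.py keeps everything needed to resume a run: epoch, architecture string, model weights, best IoU and optimizer state. Below is a minimal sketch of reloading it, assuming the Cityscapes setup (NUM_CLASSES=20), the default `lednet` model file, an Adam optimizer and a checkpoint path under `../save/`; those details live elsewhere in `train/main.py`, so they are assumptions here.

    import importlib
    import torch

    NUM_CLASSES = 20                                # Cityscapes: 19 classes + void
    model_file = importlib.import_module('lednet')  # same lookup as the --model flag
    model = torch.nn.DataParallel(model_file.Net(NUM_CLASSES)).cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)   # assumed hyper-parameters

    checkpoint = torch.load('../save/lednet_run/checkpoint.pth.tar')  # assumed path
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    start_epoch = checkpoint['epoch']
    best_acc = checkpoint['best_acc']
    print("Resuming at epoch %d (best Val-IoU so far: %.4f)" % (start_epoch, best_acc))
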
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaoyufenfei/LEDNet/5d900d9cfabb3091c952b79be34246aea0608e42/utils/__init__.py
--------------------------------------------------------------------------------
/utils/dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 |
4 | from PIL import Image
5 |
6 | from torch.utils.data import Dataset
7 |
8 | EXTENSIONS = ['.jpg', '.png']
9 |
10 | def load_image(file):
11 | return Image.open(file)
12 |
13 | def is_image(filename):
14 | return any(filename.endswith(ext) for ext in EXTENSIONS)
15 |
16 | def is_label(filename):
17 | return filename.endswith("_labelTrainIds.png")
18 |
19 | def image_path(root, basename, extension):
20 | return os.path.join(root, f'{basename}{extension}')
21 |
22 | def image_path_city(root, name):
23 | return os.path.join(root, f'{name}')
24 |
25 | def image_basename(filename):
26 | return os.path.basename(os.path.splitext(filename)[0])
27 |
28 | class VOC12(Dataset):
29 |
30 | def __init__(self, root, input_transform=None, target_transform=None):
31 | self.images_root = os.path.join(root, 'images')
32 | self.labels_root = os.path.join(root, 'labels')
33 |
34 | self.filenames = [image_basename(f)
35 | for f in os.listdir(self.labels_root) if is_image(f)]
36 | self.filenames.sort()
37 |
38 | self.input_transform = input_transform
39 | self.target_transform = target_transform
40 |
41 | def __getitem__(self, index):
42 | filename = self.filenames[index]
43 |
44 | with open(image_path(self.images_root, filename, '.jpg'), 'rb') as f:
45 | image = load_image(f).convert('RGB')
46 | with open(image_path(self.labels_root, filename, '.png'), 'rb') as f:
47 | label = load_image(f).convert('P')
48 |
49 | if self.input_transform is not None:
50 | image = self.input_transform(image)
51 | if self.target_transform is not None:
52 | label = self.target_transform(label)
53 |
54 | return image, label
55 |
56 | def __len__(self):
57 | return len(self.filenames)
58 |
59 |
60 |
61 |
62 | class cityscapes(Dataset):
63 |
64 | def __init__(self, root, co_transform=None, subset='train'):
65 | self.images_root = os.path.join(root, 'leftImg8bit/')
66 | self.labels_root = os.path.join(root, 'gtFine/')
67 | #self.labels_root = os.path.join(root, 'gtCoarse/')
68 | self.images_root += subset
69 | self.labels_root += subset
70 |
71 | print (self.images_root)
72 | #self.filenames = [image_basename(f) for f in os.listdir(self.images_root) if is_image(f)]
73 | self.filenames = [os.path.join(dp, f) for dp, dn, fn in os.walk(os.path.expanduser(self.images_root)) for f in fn if is_image(f)]
74 | self.filenames.sort()
75 |
76 | #[os.path.join(dp, f) for dp, dn, fn in os.walk(os.path.expanduser(".")) for f in fn]
77 | #self.filenamesGt = [image_basename(f) for f in os.listdir(self.labels_root) if is_image(f)]
78 | self.filenamesGt = [os.path.join(dp, f) for dp, dn, fn in os.walk(os.path.expanduser(self.labels_root)) for f in fn if is_label(f)]
79 | self.filenamesGt.sort()
80 |
81 |         self.co_transform = co_transform #joint transform applied to both image and label
82 |
83 |
84 | def __getitem__(self, index):
85 | filename = self.filenames[index]
86 | filenameGt = self.filenamesGt[index]
87 |
88 | with open(image_path_city(self.images_root, filename), 'rb') as f:
89 | image = load_image(f).convert('RGB')
90 | with open(image_path_city(self.labels_root, filenameGt), 'rb') as f:
91 | label = load_image(f).convert('P')
92 |
93 | if self.co_transform is not None:
94 | image, label = self.co_transform(image, label)
95 |
96 | return image, label
97 |
98 | def __len__(self):
99 | return len(self.filenames)
100 |
101 |
--------------------------------------------------------------------------------
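A small sketch of how the `cityscapes` dataset above is typically consumed: a joint transform that resizes and tensorizes image and label together, wrapped in a DataLoader. `MyCoTransform` here is illustrative only (the real augmentation pipeline lives in train/main.py), and it assumes `utils/` is on the Python path and the root path points at the Cityscapes dataset.

    import torch
    from PIL import Image
    from torch.utils.data import DataLoader
    from torchvision.transforms import ToTensor

    from dataset import cityscapes
    from transform import ToLabel, Relabel

    class MyCoTransform(object):
        # Illustrative joint transform: resize image and label together, then tensorize.
        def __init__(self, height=512):
            self.height = height
        def __call__(self, image, label):
            image = image.resize((2 * self.height, self.height), Image.BILINEAR)
            label = label.resize((2 * self.height, self.height), Image.NEAREST)
            image = ToTensor()(image)
            label = Relabel(255, 19)(ToLabel()(label))   # map the 255 'void' id to trainId 19
            return image, label

    dataset = cityscapes('../datasets/cityscapes/', co_transform=MyCoTransform(512), subset='train')
    loader = DataLoader(dataset, num_workers=4, batch_size=5, shuffle=True)
    images, labels = next(iter(loader))   # images: 5x3x512x1024, labels: 5x1x512x1024
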
/utils/iouEval.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | class iouEval:
4 |
5 | def __init__(self, nClasses, ignoreIndex=19):
6 | self.nClasses = nClasses
7 |         self.ignoreIndex = ignoreIndex if nClasses>ignoreIndex else -1 #if ignoreIndex is not a valid class index, use no ignore index
8 | self.reset()
9 |
10 | def reset (self):
11 | classes = self.nClasses if self.ignoreIndex==-1 else self.nClasses-1
12 | self.tp = torch.zeros(classes).double()
13 | self.fp = torch.zeros(classes).double()
14 | self.fn = torch.zeros(classes).double()
15 |
16 | def addBatch(self, x, y): #x=preds, y=targets
17 | #sizes should be "batch_size x nClasses x H x W"
18 |
19 | #print ("X is cuda: ", x.is_cuda)
20 | #print ("Y is cuda: ", y.is_cuda)
21 |
22 | if (x.is_cuda or y.is_cuda):
23 | x = x.cuda()
24 | y = y.cuda()
25 |
26 | #if size is "batch_size x 1 x H x W" scatter to onehot
27 | if (x.size(1) == 1):
28 | x_onehot = torch.zeros(x.size(0), self.nClasses, x.size(2), x.size(3))
29 | if x.is_cuda:
30 | x_onehot = x_onehot.cuda()
31 | x_onehot.scatter_(1, x, 1).float()
32 | else:
33 | x_onehot = x.float()
34 |
35 | if (y.size(1) == 1):
36 | y_onehot = torch.zeros(y.size(0), self.nClasses, y.size(2), y.size(3))
37 | if y.is_cuda:
38 | y_onehot = y_onehot.cuda()
39 | y_onehot.scatter_(1, y, 1).float()
40 | else:
41 | y_onehot = y.float()
42 |
43 | if (self.ignoreIndex != -1):
44 | ignores = y_onehot[:,self.ignoreIndex].unsqueeze(1)
45 | x_onehot = x_onehot[:, :self.ignoreIndex]
46 | y_onehot = y_onehot[:, :self.ignoreIndex]
47 | else:
48 | ignores=0
49 |
50 | #print(type(x_onehot))
51 | #print(type(y_onehot))
52 | #print(x_onehot.size())
53 | #print(y_onehot.size())
54 |
55 |         tpmult = x_onehot * y_onehot    #1 where prediction and ground truth agree (true positive)
56 | tp = torch.sum(torch.sum(torch.sum(tpmult, dim=0, keepdim=True), dim=2, keepdim=True), dim=3, keepdim=True).squeeze()
57 |         fpmult = x_onehot * (1-y_onehot-ignores) #1 where the prediction claims a class the ground truth does not have (ignore-label pixels excluded)
58 | fp = torch.sum(torch.sum(torch.sum(fpmult, dim=0, keepdim=True), dim=2, keepdim=True), dim=3, keepdim=True).squeeze()
59 |         fnmult = (1-x_onehot) * (y_onehot) #1 where the prediction misses a class the ground truth has
60 | fn = torch.sum(torch.sum(torch.sum(fnmult, dim=0, keepdim=True), dim=2, keepdim=True), dim=3, keepdim=True).squeeze()
61 |
62 | self.tp += tp.double().cpu()
63 | self.fp += fp.double().cpu()
64 | self.fn += fn.double().cpu()
65 |
66 | def getIoU(self):
67 | num = self.tp
68 | den = self.tp + self.fp + self.fn + 1e-15
69 | iou = num / den
70 | return torch.mean(iou), iou #returns "iou mean", "iou per class"
71 |
72 | # Class for colors
73 | class colors:
74 | RED = '\033[31;1m'
75 | GREEN = '\033[32;1m'
76 | YELLOW = '\033[33;1m'
77 | BLUE = '\033[34;1m'
78 | MAGENTA = '\033[35;1m'
79 | CYAN = '\033[36;1m'
80 | BOLD = '\033[1m'
81 | UNDERLINE = '\033[4m'
82 | ENDC = '\033[0m'
83 |
84 | # Colored value output if colorized flag is activated.
85 | def getColorEntry(val):
86 | if not isinstance(val, float):
87 | return colors.ENDC
88 | if (val < .20):
89 | return colors.RED
90 | elif (val < .40):
91 | return colors.YELLOW
92 | elif (val < .60):
93 | return colors.BLUE
94 | elif (val < .80):
95 | return colors.CYAN
96 | else:
97 | return colors.GREEN
98 |
99 |
--------------------------------------------------------------------------------
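A minimal sketch of feeding `iouEval` above: it expects class-index maps of shape batch x 1 x H x W (they are one-hot encoded internally), and with the default ignoreIndex=19 it reports IoU over the 19 Cityscapes classes only. The random tensors below are placeholders for real predictions and labels.

    import torch
    from iouEval import iouEval

    NUM_CLASSES = 20                          # 19 classes + void (index 19 is ignored)
    evaluator = iouEval(NUM_CLASSES)

    # Placeholder batch: argmax predictions and targets as B x 1 x H x W class maps.
    preds   = torch.randint(0, NUM_CLASSES, (4, 1, 64, 128)).long()
    targets = torch.randint(0, NUM_CLASSES, (4, 1, 64, 128)).long()
    evaluator.addBatch(preds, targets)

    mean_iou, per_class_iou = evaluator.getIoU()
    print("mean IoU: %.4f" % mean_iou.item())
    print("per-class IoU:", per_class_iou.tolist())
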
/utils/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class CrossEntropyLoss2d(torch.nn.Module):
7 |
8 | def __init__(self, weight=None):
9 | super(CrossEntropyLoss2d,self).__init__()
10 |
11 | self.loss = nn.NLLLoss(weight)
12 |
13 | def forward(self, outputs, targets):
14 | return self.loss(F.log_softmax(outputs, dim=1), targets)
15 |
16 |
17 |
18 |
19 |
20 |
21 |
--------------------------------------------------------------------------------
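A sketch of how `CrossEntropyLoss2d` above is applied: raw logits of shape B x C x H x W against a target class map of shape B x H x W, optionally with per-class weights. The weight values here are placeholders; the actual Cityscapes class weights are set in train/main.py.

    import torch
    from loss import CrossEntropyLoss2d

    NUM_CLASSES = 20
    weight = torch.ones(NUM_CLASSES)     # placeholder per-class weights
    weight[19] = 0                       # give the void class no influence on the loss

    criterion = CrossEntropyLoss2d(weight)

    outputs = torch.randn(4, NUM_CLASSES, 64, 128)                # raw network logits
    targets = torch.randint(0, NUM_CLASSES, (4, 64, 128)).long()  # class index per pixel
    loss = criterion(outputs, targets)
    print(loss.item())
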
/utils/transform.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | from PIL import Image
5 |
6 | def colormap_cityscapes(n):
7 | cmap=np.zeros([n, 3]).astype(np.uint8)
8 | cmap[0,:] = np.array([128, 64,128])
9 | cmap[1,:] = np.array([244, 35,232])
10 | cmap[2,:] = np.array([ 70, 70, 70])
11 | cmap[3,:] = np.array([ 102,102,156])
12 | cmap[4,:] = np.array([ 190,153,153])
13 | cmap[5,:] = np.array([ 153,153,153])
14 |
15 | cmap[6,:] = np.array([ 250,170, 30])
16 | cmap[7,:] = np.array([ 220,220, 0])
17 | cmap[8,:] = np.array([ 107,142, 35])
18 | cmap[9,:] = np.array([ 152,251,152])
19 | cmap[10,:] = np.array([ 70,130,180])
20 |
21 | cmap[11,:] = np.array([ 220, 20, 60])
22 | cmap[12,:] = np.array([ 255, 0, 0])
23 | cmap[13,:] = np.array([ 0, 0,142])
24 | cmap[14,:] = np.array([ 0, 0, 70])
25 | cmap[15,:] = np.array([ 0, 60,100])
26 |
27 | cmap[16,:] = np.array([ 0, 80,100])
28 | cmap[17,:] = np.array([ 0, 0,230])
29 | cmap[18,:] = np.array([ 119, 11, 32])
30 | cmap[19,:] = np.array([ 0, 0, 0])
31 |
32 | return cmap
33 |
34 |
35 | def colormap(n):
36 | cmap=np.zeros([n, 3]).astype(np.uint8)
37 |
38 | for i in np.arange(n):
39 | r, g, b = np.zeros(3)
40 |
41 | for j in np.arange(8):
42 | r = r + (1<<(7-j))*((i&(1<<(3*j))) >> (3*j))
43 | g = g + (1<<(7-j))*((i&(1<<(3*j+1))) >> (3*j+1))
44 | b = b + (1<<(7-j))*((i&(1<<(3*j+2))) >> (3*j+2))
45 |
46 | cmap[i,:] = np.array([r, g, b])
47 |
48 | return cmap
49 |
50 | class Relabel:
51 |
52 | def __init__(self, olabel, nlabel):
53 | self.olabel = olabel
54 | self.nlabel = nlabel
55 |
56 | def __call__(self, tensor):
57 |         assert (isinstance(tensor, torch.LongTensor) or isinstance(tensor, torch.ByteTensor)) , 'tensor needs to be LongTensor or ByteTensor'
58 | tensor[tensor == self.olabel] = self.nlabel
59 | return tensor
60 |
61 |
62 | class ToLabel:
63 |
64 | def __call__(self, image):
65 | return torch.from_numpy(np.array(image)).long().unsqueeze(0)
66 |
67 |
68 | class Colorize:
69 |
70 | def __init__(self, n=22):
71 | #self.cmap = colormap(256)
72 | self.cmap = colormap_cityscapes(256)
73 | self.cmap[n] = self.cmap[-1]
74 | self.cmap = torch.from_numpy(self.cmap[:n])
75 |
76 | def __call__(self, gray_image):
77 | size = gray_image.size()
78 | #print(size)
79 | color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0)
80 | #color_image = torch.ByteTensor(3, size[0], size[1]).fill_(0)
81 |
82 | #for label in range(1, len(self.cmap)):
83 | for label in range(0, len(self.cmap)):
84 | mask = gray_image[0] == label
85 | #mask = gray_image == label
86 |
87 | color_image[0][mask] = self.cmap[label][0]
88 | color_image[1][mask] = self.cmap[label][1]
89 | color_image[2][mask] = self.cmap[label][2]
90 |
91 | return color_image
92 |
--------------------------------------------------------------------------------
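A short sketch of the two directions these transforms are used in: `ToLabel` + `Relabel` to turn a Cityscapes `*_labelTrainIds.png` into a training target, and `Colorize` to turn a predicted class map back into an RGB image for inspection. The file paths and the random logits are placeholders.

    import torch
    from PIL import Image
    from torchvision.transforms import ToPILImage

    from transform import ToLabel, Relabel, Colorize

    # Ground truth -> training target: PIL 'P' image to a 1 x H x W LongTensor,
    # with the 255 'void' id remapped to trainId 19.
    label = Image.open('example_labelTrainIds.png').convert('P')   # placeholder path
    target = Relabel(255, 19)(ToLabel()(label))

    # Prediction -> color image: argmax over the class dimension, then colorize.
    logits = torch.randn(1, 20, 512, 1024)             # placeholder network output
    pred = logits[0].max(0)[1].unsqueeze(0)            # 1 x H x W class-index map
    color = ToPILImage()(Colorize(20)(pred))
    color.save('example_color.png')
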
/utils/visualize.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from torch.autograd import Variable
4 |
5 | from visdom import Visdom
6 |
7 | class Dashboard:
8 |
9 | def __init__(self, port):
10 | self.vis = Visdom(port=port)
11 |
12 | def loss(self, losses, title):
13 | x = np.arange(1, len(losses)+1, 1)
14 |
15 | self.vis.line(losses, x, env='loss', opts=dict(title=title))
16 |
17 | def image(self, image, title):
18 | if image.is_cuda:
19 | image = image.cpu()
20 | if isinstance(image, Variable):
21 | image = image.data
22 | image = image.numpy()
23 |
24 | self.vis.image(image, env='images', opts=dict(title=title))
--------------------------------------------------------------------------------
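A minimal sketch of driving the `Dashboard` wrapper above; it assumes a visdom server is already running (for example `python -m visdom.server -port 8097`), which is what the `--visualize` and `--port` flags in train/main.py hook into. The loss curve and image are placeholders.

    import numpy as np
    import torch

    from visualize import Dashboard

    board = Dashboard(8097)                  # port must match the running visdom server

    losses = np.random.rand(10)              # placeholder per-step loss values
    board.loss(losses, 'train loss (demo)')

    image = torch.rand(3, 128, 256)          # placeholder CHW image
    board.image(image, 'input (demo)')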