├── .gitignore ├── LICENSE ├── README.md ├── checkpoint └── .gitkeep ├── datasets ├── README.md ├── cityscapes │ ├── gtCoarse │ │ └── .gitkeep │ ├── gtFine │ │ └── .gitkeep │ ├── leftImg8bit │ │ └── .gitkeep │ └── results │ │ └── .gitkeep └── cityscapesscripts │ └── .gitkeep ├── imagenet-pretrain ├── README.md ├── lednet_imagenet.py └── main.py ├── images ├── LEDNet_demo.png └── LEDNet_overview.png ├── requirements.txt ├── save └── .gitkeep ├── test ├── README.md ├── dataset.py ├── eval_cityscapes_color.py ├── eval_cityscapes_server.py ├── eval_forward_time.py ├── eval_iou.py ├── iouEval.py ├── lednet_no_bn.py └── transform.py ├── train ├── README.md ├── lednet.py ├── lednet_1.py └── main.py └── utils ├── __init__.py ├── dataset.py ├── iouEval.py ├── loss.py ├── transform.py └── visualize.py /.gitignore: -------------------------------------------------------------------------------- 1 | #Files 2 | *.pyc 3 | *.pyo 4 | */__pycache__/ 5 | */*/__pycache__/ 6 | */*/*/__pycache__/ 7 | test/save_results/ 8 | test/save_color/ 9 | .idea/ 10 | 11 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Yu Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ### [LEDNet: A Lightweight Encoder-Decoder Network for Real-time Semantic Segmentation](https://github.com/xiaoyufenfei/LEDNet) 2 | 3 | [![python-image]][python-url] 4 | [![pytorch-image]][pytorch-url] 5 | 6 | #### Table of Contents: 7 | - Introduction 8 | - Project Structure 9 | - Installation 10 | - Datasets 11 | - Train 12 | - Resuming training 13 | - Test 14 | - Results 15 | - Reference 16 | - Tips 17 | 18 | #### Introduction 19 | 20 | This project contains the code (Note: The code is test in the environment with python=3.6, cuda=9.0, PyTorch-0.4.1, also support Pytorch-0.4.1+) for: [**LEDNet: A Lightweight Encoder-Decoder Network for Real-time Semantic Segmentation**](https://arxiv.org/pdf/1905.02423.pdf) by [Yu Wang](https://github.com/xiaoyufenfei). 21 | 22 |

23 | The extensive computational burden limits the usage of CNNs in mobile devices for dense estimation tasks, a.k.a. semantic segmentation. In this paper, we present a lightweight network to address this problem, namely **LEDNet**, which employs an asymmetric encoder-decoder architecture for the task of real-time semantic segmentation. More specifically, the encoder adopts a ResNet as its backbone network, where two new operations, channel split and shuffle, are utilized in each residual block to greatly reduce computation cost while maintaining higher segmentation accuracy. On the other hand, an attention pyramid network (APN) is employed in the decoder to further lighten the entire network complexity. Our model has less than 1M parameters and is able to run at over 71 FPS on a single GTX 1080Ti GPU card. Comprehensive experiments demonstrate that our approach achieves state-of-the-art results in terms of the speed/accuracy trade-off on the Cityscapes dataset, making it an effective method for real-time semantic segmentation tasks.
24 | 
25 | #### Project-Structure
26 | ```
27 | ├── datasets                 # contains all datasets for the project
28 | |  └── cityscapes            # Cityscapes dataset
29 | |  |  └── gtCoarse           # coarse Cityscapes annotations
30 | |  |  └── gtFine             # fine Cityscapes annotations
31 | |  |  └── leftImg8bit        # Cityscapes RGB images
32 | |  └── cityscapesscripts     # Cityscapes label conversion scripts
33 | ├── utils
34 | |  └── dataset.py            # dataloader for the Cityscapes dataset
35 | |  └── iouEval.py            # computes mean IoU and per-class IoU
36 | |  └── transform.py          # data preprocessing
37 | |  └── visualize.py          # visualization with Visdom
38 | |  └── loss.py               # loss function
39 | ├── checkpoint
40 | |  └── xxx.pth               # encoder weights pretrained on ImageNet
41 | ├── save
42 | |  └── xxx.pth               # models trained from scratch
43 | ├── imagenet-pretrain
44 | |  └── lednet_imagenet.py    # encoder + classifier used for ImageNet pretraining
45 | |  └── main.py               # ImageNet pretraining script
46 | ├── train
47 | |  └── lednet.py             # model definition for semantic segmentation
48 | |  └── main.py               # training script
49 | ├── test
50 | |  |  └── dataset.py
51 | |  |  └── lednet.py                   # model definition
52 | |  |  └── lednet_no_bn.py             # model definition with the BN layers removed
53 | |  |  └── eval_cityscapes_color.py    # generates color (RGB) segmentation results
54 | |  |  └── eval_cityscapes_server.py   # generates results for upload to the official server
55 | |  |  └── eval_forward_time.py        # measures model inference time
56 | |  |  └── eval_iou.py
57 | |  |  └── iouEval.py
58 | |  |  └── transform.py
59 | ```
60 | 
61 | #### Installation
62 | - Python 3.6.x. [Anaconda3](https://www.anaconda.com/distribution/) is recommended.
63 | - Set up the Python environment:
64 | 
65 | ```
66 | pip3 install -r requirements.txt
67 | ```
68 | 
69 | - Environment: PyTorch 0.4.1; CUDA 9.0; cuDNN 7.1; Python 3.6
70 | 
71 | - Clone this repository:
72 | 
73 | ```
74 | git clone https://github.com/xiaoyufenfei/LEDNet.git
75 | cd LEDNet-master
76 | ```
77 | 
78 | - Install [Visdom](https://github.com/facebookresearch/visdom).
79 | - Install [torchsummary](https://github.com/sksq96/pytorch-summary).
80 | - Download the dataset by following the **Datasets** section below.
81 | - Note: For training, we currently support [Cityscapes](https://www.cityscapes-dataset.com/); we aim to add the [CamVid](https://github.com/alexgkendall/SegNet-Tutorial/tree/master/CamVid), [VOC](http://host.robots.ox.ac.uk/pascal/VOC/) and [ADE20K](http://groups.csail.mit.edu/vision/datasets/ADE20K/) datasets.
82 | 
83 | #### Datasets
84 | - You can download [Cityscapes](https://www.cityscapes-dataset.com/) from [here](https://www.cityscapes-dataset.com/downloads/). Note: please download [leftImg8bit_trainvaltest.zip(11GB)](https://www.cityscapes-dataset.com/file-handling/?packageID=4), [gtFine_trainvaltest(241MB)](https://www.cityscapes-dataset.com/file-handling/?packageID=1) and [gtCoarse(1.3GB)](https://www.cityscapes-dataset.com/file-handling/?packageID=1).
85 | - You can download [CityscapesScripts](https://github.com/mcordts/cityscapesScripts) and use it to convert the annotations to the [19 training categories](https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/helpers/labels.py). The dataset should have this basic structure:
86 | 
87 | ```
88 | ├── leftImg8bit
89 | │   ├── train
90 | │   ├── val
91 | │   └── test
92 | ├── gtFine
93 | │   ├── train
94 | │   ├── val
95 | │   └── test
96 | ├── gtCoarse
97 | │   ├── train
98 | │   ├── train_extra
99 | │   └── val
100 | ```
101 | 
102 | #### Training-LEDNet
103 | 
104 | - For help on the optional arguments you can run: `python main.py -h`
105 | 
106 | - By default, we assume you have downloaded the Cityscapes dataset into the `./datasets/cityscapes` dir.
107 | - To train LEDNet, run the `train/main.py` script and pass the parameters listed in `main.py` as flags, or change their default values manually.
108 | 
109 | ```
110 | python main.py --savedir logs --model lednet --datadir path/root_directory/ --num-epochs xx --batch-size xx ...
111 | ```
112 | 
113 | #### Resuming-training-if-decoder-part-interrupted
114 | 
115 | - For help on the optional arguments you can run: `python main.py -h`
116 | 
117 | ```
118 | python main.py --savedir logs --name lednet --datadir path/root_directory/ --num-epochs xx --batch-size xx --decoder --state "../save/logs/model_best_enc.pth.tar"...
119 | ```
120 | 
121 | #### Testing
122 | 
123 | - The models produced during training can be found [here](https://github.com/xiaoyufenfei/LEDNet/save/). These may not be the best ones; you can train a model from scratch yourself, or fine-tune the decoder with an encoder pre-trained on ImageNet. For instance:
124 | 
125 | ```
126 | for more details refer to ./test/README.md
127 | ```
128 | 
129 | #### Results
130 | 
131 | - Please refer to our article for more details.
132 | 
133 | | Method | Dataset | Fine | Coarse | IoU (class) | IoU (category) | FPS |
134 | |:-:|:-:|:-:|:-:|:-:|:-:|:-:|
135 | |**LEDNet**|**cityscapes**|yes|yes|**70.6%**|**87.1%**|**70+**|
136 | 
137 | Qualitative segmentation result examples:
138 | 
139 | <p align="center"><img src="./images/LEDNet_demo.png" /></p>
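For reference, the channel split and channel shuffle operations described in the Introduction are what each SS-nbt residual block applies to its input channels; they are implemented in `imagenet-pretrain/lednet_imagenet.py` (and the other model files). The sketch below condenses the two helpers to show the idea; the convolutions, residual add and activations of the full block are omitted.

```
import torch

def split(x):
    # split the feature map into two halves along the channel dimension
    c = x.size(1) // 2
    return x[:, :c].contiguous(), x[:, c:].contiguous()

def channel_shuffle(x, groups):
    # interleave channels from the two branches so information is exchanged between them
    n, c, h, w = x.size()
    x = x.view(n, groups, c // groups, h, w).transpose(1, 2).contiguous()
    return x.view(n, -1, h, w)

x = torch.randn(2, 64, 32, 32)
x1, x2 = split(x)                     # each branch convolves only 32 of the 64 channels
out = torch.cat([x1, x2], dim=1)      # concatenate the two branch outputs back together
out = channel_shuffle(out, groups=2)  # shuffle so the next block mixes both halves
```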

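Before running the evaluation scripts in `test/`, a dummy forward pass is a quick way to sanity-check the model definition. This is only a sketch, not part of the repository: it assumes `lednet.py` (the `Net` class the test scripts import) is on the Python path, and uses the 20-class setup (19 Cityscapes classes plus the ignore class) and the 512x1024 resolution used by the evaluation scripts.

```
import torch
from lednet import Net   # model definition, as imported by the test scripts

model = Net(20)           # NUM_CLASSES = 20 in the evaluation scripts
model.eval()

with torch.no_grad():
    out = model(torch.randn(1, 3, 512, 1024))   # dummy batch at the eval resolution

print(out.shape)          # expected to be a [1, 20, H, W] class-score map
```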
140 | 
141 | #### Citation
142 | 
143 | If you find this code useful for your research, please use the following BibTeX entry.
144 | 
145 | ```
146 | @article{wang2019lednet,
147 |  title={LEDNet: A Lightweight Encoder-Decoder Network for Real-time Semantic Segmentation},
148 |  author={Wang, Yu and Zhou, Quan and Liu, Jia and Xiong, Jian and Gao, Guangwei and Wu, Xiaofu and Latecki, Longin Jan},
149 |  journal={arXiv preprint arXiv:1905.02423},
150 |  year={2019}
151 | }
152 | ```
153 | 
154 | #### Tips
155 | 
156 | - Limited by GPU resources, the project results still need to be further improved...
157 | - It is recommended to pre-train the encoder on ImageNet and then fine-tune the decoder part; this gives better results.
158 | 
159 | #### Reference
160 | 
161 | 1. [**Deep residual learning for image recognition**](https://arxiv.org/pdf/1512.03385.pdf)
162 | 2. [**Enet: A deep neural network architecture for real-time semantic segmentation**](https://arxiv.org/pdf/1606.02147.pdf)
163 | 3. [**Erfnet: Efficient residual factorized convnet for real-time semantic segmentation**](https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=8063438)
164 | 4. [**Shufflenet: An extremely efficient convolutional neural network for mobile devices**](https://arxiv.org/pdf/1707.01083.pdf)
165 | 
166 | 
170 | 
171 | [python-image]: https://img.shields.io/badge/Python-3.x-ff69b4.svg
172 | [python-url]: https://www.python.org/
173 | [pytorch-image]: https://img.shields.io/badge/PyTorch-1.0-2BAF2B.svg
174 | [pytorch-url]: https://pytorch.org/
-------------------------------------------------------------------------------- /checkpoint/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiaoyufenfei/LEDNet/5d900d9cfabb3091c952b79be34246aea0608e42/checkpoint/.gitkeep
-------------------------------------------------------------------------------- /datasets/README.md: --------------------------------------------------------------------------------
1 | #### Cityscapes dataset overview:
2 | 
3 | 1. The data must be **registered for and downloaded** from the official website, https://www.cityscapes-dataset.com/. The official benchmarks page also shows the scores that other networks have reached.
4 | 
5 | 2. The data preprocessing and result-evaluation code can be **downloaded** from https://github.com/mcordts/cityscapesScripts.
6 | 
7 | The original images are stored in the **leftImg8bit** folder, and the finely annotated labels are stored in the **gtFine** (gt: ground truth) folder. The training set (train) contains **2975** images and the validation set (val) contains **500** images, all with corresponding labels. The test set (test) provides only the original images without labels; it is used officially to evaluate submitted results (to prevent anyone from training on the test set to inflate their scores). Therefore, in practice, you can use the validation set for testing. Coarsely labeled data is stored in the **gtCoarse** (gt: ground truth) folder.
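Note that the dataloaders in this project look for `*_labelTrainIds.png` files (see `test/dataset.py`), which are not part of the raw download; they are normally generated with `cityscapesscripts/preparation/createTrainIdLabelImgs.py` from the cityscapesScripts package linked above. The snippet below is only a rough single-image sketch of that conversion, meant to illustrate the id-to-trainId mapping defined in `cityscapesscripts/helpers/labels.py`; the helper name and file paths are placeholders.

```
# Rough sketch: convert one *_labelIds.png (ids 0-33) to *_labelTrainIds.png (19 classes + 255 ignore).
import numpy as np
from PIL import Image
from cityscapesscripts.helpers.labels import id2label   # id -> Label mapping from labels.py

def label_ids_to_train_ids(src_png, dst_png):
    ids = np.array(Image.open(src_png), dtype=np.uint8)
    train_ids = np.full_like(ids, 255)                   # 255 marks ignored classes
    for label_id, label in id2label.items():
        if 0 <= label.trainId < 19:                      # keep only the 19 training classes
            train_ids[ids == label_id] = label.trainId
    Image.fromarray(train_ids).save(dst_png)

# label_ids_to_train_ids("..._gtFine_labelIds.png", "..._gtFine_labelTrainIds.png")
```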
8 | 
9 | Each annotated image corresponds to four label files: _gtFine_polygons.json stores the classes and their regions (a region boundary is given by the positions of the polygon's vertices); _gtFine_labelIds.png contains values 0-33, where different values represent different classes and the mapping between values and classes is defined in cityscapesscripts/helpers/labels.py; _gtFine_instanceIds.png contains the instance segmentation; and _gtFine_color.png is a color rendering for visualization, whose color-to-class mapping is also described in labels.py.
10 | 
11 | #### Dataset Structure:
12 | 
13 | ```
14 | ├── datasets                 # contains all datasets for the project
15 | |  └── cityscapes            # Cityscapes dataset
16 | |  |  └── gtCoarse           # coarse Cityscapes annotations
17 | |  |  └── gtFine             # fine Cityscapes annotations
18 | |  |  └── leftImg8bit        # Cityscapes RGB images
19 | |  |  └── results            # results are moved here for evaluation by evalPixelLevelSemanticLabeling.py
20 | |  └── cityscapesscripts     # Cityscapes label conversion scripts
21 | |  |  └── annotation         # annotation tool
22 | |  |  └── evaluation         # evaluation scripts
23 | |  |  └── helpers            # helper files, including labels.py
24 | |  |  └── preparation        # label conversion/preparation scripts
25 | |  |  └── viewer             # dataset viewer
26 | |  |  └── __init__.py        #
27 | 
28 | ```
29 | 
30 | 
-------------------------------------------------------------------------------- /datasets/cityscapes/gtCoarse/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiaoyufenfei/LEDNet/5d900d9cfabb3091c952b79be34246aea0608e42/datasets/cityscapes/gtCoarse/.gitkeep
-------------------------------------------------------------------------------- /datasets/cityscapes/gtFine/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiaoyufenfei/LEDNet/5d900d9cfabb3091c952b79be34246aea0608e42/datasets/cityscapes/gtFine/.gitkeep
-------------------------------------------------------------------------------- /datasets/cityscapes/leftImg8bit/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiaoyufenfei/LEDNet/5d900d9cfabb3091c952b79be34246aea0608e42/datasets/cityscapes/leftImg8bit/.gitkeep
-------------------------------------------------------------------------------- /datasets/cityscapes/results/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiaoyufenfei/LEDNet/5d900d9cfabb3091c952b79be34246aea0608e42/datasets/cityscapes/results/.gitkeep
-------------------------------------------------------------------------------- /datasets/cityscapesscripts/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiaoyufenfei/LEDNet/5d900d9cfabb3091c952b79be34246aea0608e42/datasets/cityscapesscripts/.gitkeep
-------------------------------------------------------------------------------- /imagenet-pretrain/README.md: --------------------------------------------------------------------------------
1 | #### Model and ImageNet pretraining script
2 | 
3 | This folder contains the script and model definition used to pretrain LEDNet's encoder on ImageNet data.
4 | 
5 | The script is an adaptation of the code in the [PyTorch ImageNet example](https://github.com/pytorch/examples/tree/master/imagenet). Please make sure that you have the ImageNet dataset split into train and val folders before launching the script.
Refer to that repository for instructions about usage and main.py options. Basic command: 6 | 7 | ``` 8 | python main.py 9 | ``` 10 | 11 | 12 | #### Third Party Project Reference 13 | 14 | - [ImageNet training in PyTorch](https://github.com/pytorch/examples/tree/master/imagenet) 15 | 16 | - [ShuffleNetv2 in PyTorch](https://github.com/Randl/ShuffleNetV2-pytorch) 17 | 18 | -------------------------------------------------------------------------------- /imagenet-pretrain/lednet_imagenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import torch.nn.functional as F 5 | 6 | def split(x): 7 | c = int(x.size()[1]) 8 | c1 = round(c * 0.5) 9 | x1 = x[:, :c1, :, :].contiguous() 10 | x2 = x[:, c1:, :, :].contiguous() 11 | 12 | return x1, x2 13 | 14 | def channel_shuffle(x,groups): 15 | batchsize, num_channels, height, width = x.data.size() 16 | 17 | channels_per_group = num_channels // groups 18 | 19 | # reshape 20 | x = x.view(batchsize,groups, 21 | channels_per_group,height,width) 22 | 23 | x = torch.transpose(x,1,2).contiguous() 24 | 25 | # flatten 26 | x = x.view(batchsize,-1,height,width) 27 | 28 | return x 29 | 30 | 31 | class Conv2dBnRelu(nn.Module): 32 | def __init__(self,in_ch,out_ch,kernel_size=3,stride=1,padding=0,dilation=1,bias=True): 33 | super(Conv2dBnRelu,self).__init__() 34 | 35 | self.conv = nn.Sequential( 36 | nn.Conv2d(in_ch,out_ch,kernel_size,stride,padding,dilation=dilation,bias=bias), 37 | nn.BatchNorm2d(out_ch, eps=1e-3), 38 | nn.ReLU(inplace=True) 39 | ) 40 | 41 | def forward(self, x): 42 | return self.conv(x) 43 | 44 | 45 | # after Concat -> BN, you also can use Dropout like SS_nbt_module may be make a good result! 46 | class DownsamplerBlock (nn.Module): 47 | def __init__(self, in_channel, out_channel): 48 | super(DownsamplerBlock,self).__init__() 49 | 50 | self.conv = nn.Conv2d(in_channel, out_channel-in_channel, (3, 3), stride=2, padding=1, bias=True) 51 | self.pool = nn.MaxPool2d(2, stride=2) 52 | self.bn = nn.BatchNorm2d(out_channel, eps=1e-3) 53 | self.relu = nn.ReLU(inplace=True) 54 | 55 | def forward(self, input): 56 | x1 = self.pool(input) 57 | x2 = self.conv(input) 58 | 59 | diffY = x2.size()[2] - x1.size()[2] 60 | diffX = x2.size()[3] - x1.size()[3] 61 | 62 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, 63 | diffY // 2, diffY - diffY // 2]) 64 | 65 | output = torch.cat([x2, x1], 1) 66 | output = self.bn(output) 67 | output = self.relu(output) 68 | return output 69 | 70 | 71 | class SS_nbt_module(nn.Module): 72 | def __init__(self, chann, dropprob, dilated): 73 | super().__init__() 74 | 75 | oup_inc = chann//2 76 | 77 | # dw 78 | self.conv3x1_1_l = nn.Conv2d(oup_inc, oup_inc, (3,1), stride=1, padding=(1,0), bias=True) 79 | 80 | self.conv1x3_1_l = nn.Conv2d(oup_inc, oup_inc, (1,3), stride=1, padding=(0,1), bias=True) 81 | 82 | self.bn1_l = nn.BatchNorm2d(oup_inc, eps=1e-03) 83 | 84 | self.conv3x1_2_l = nn.Conv2d(oup_inc, oup_inc, (3,1), stride=1, padding=(1*dilated,0), bias=True, dilation = (dilated,1)) 85 | 86 | self.conv1x3_2_l = nn.Conv2d(oup_inc, oup_inc, (1,3), stride=1, padding=(0,1*dilated), bias=True, dilation = (1,dilated)) 87 | 88 | self.bn2_l = nn.BatchNorm2d(oup_inc, eps=1e-03) 89 | 90 | # dw 91 | self.conv3x1_1_r = nn.Conv2d(oup_inc, oup_inc, (3,1), stride=1, padding=(1,0), bias=True) 92 | 93 | self.conv1x3_1_r = nn.Conv2d(oup_inc, oup_inc, (1,3), stride=1, padding=(0,1), bias=True) 94 | 95 | self.bn1_r = nn.BatchNorm2d(oup_inc, 
eps=1e-03) 96 | 97 | self.conv3x1_2_r = nn.Conv2d(oup_inc, oup_inc, (3,1), stride=1, padding=(1*dilated,0), bias=True, dilation = (dilated,1)) 98 | 99 | self.conv1x3_2_r = nn.Conv2d(oup_inc, oup_inc, (1,3), stride=1, padding=(0,1*dilated), bias=True, dilation = (1,dilated)) 100 | 101 | self.bn2_r = nn.BatchNorm2d(oup_inc, eps=1e-03) 102 | 103 | self.relu = nn.ReLU(inplace=True) 104 | self.dropout = nn.Dropout2d(dropprob) 105 | 106 | @staticmethod 107 | def _concat(x,out): 108 | return torch.cat((x,out),1) 109 | 110 | def forward(self, input): 111 | 112 | # x1 = input[:,:(input.shape[1]//2),:,:] 113 | # x2 = input[:,(input.shape[1]//2):,:,:] 114 | residual = input 115 | x1, x2 = split(input) 116 | 117 | output1 = self.conv3x1_1_l(x1) 118 | output1 = self.relu(output1) 119 | output1 = self.conv1x3_1_l(output1) 120 | output1 = self.bn1_l(output1) 121 | output1 = self.relu(output1) 122 | 123 | output1 = self.conv3x1_2_l(output1) 124 | output1 = self.relu(output1) 125 | output1 = self.conv1x3_2_l(output1) 126 | output1 = self.bn2_l(output1) 127 | 128 | 129 | output2 = self.conv1x3_1_r(x2) 130 | output2 = self.relu(output2) 131 | output2 = self.conv3x1_1_r(output2) 132 | output2 = self.bn1_r(output2) 133 | output2 = self.relu(output2) 134 | 135 | output2 = self.conv1x3_2_r(output2) 136 | output2 = self.relu(output2) 137 | output2 = self.conv3x1_2_r(output2) 138 | output2 = self.bn2_r(output2) 139 | 140 | if (self.dropout.p != 0): 141 | output1 = self.dropout(output1) 142 | output2 = self.dropout(output2) 143 | 144 | out = self._concat(output1,output2) 145 | out = F.relu(residual + out, inplace=True) 146 | return channel_shuffle(out, 2) 147 | 148 | 149 | class Encoder(nn.Module): 150 | def __init__(self): 151 | super().__init__() 152 | self.initial_block = DownsamplerBlock(3,32) 153 | 154 | self.layers = nn.ModuleList() 155 | 156 | for x in range(0, 3): 157 | self.layers.append(SS_nbt_module(32, 0.03, 1)) 158 | 159 | 160 | self.layers.append(DownsamplerBlock(32,64)) 161 | 162 | 163 | for x in range(0, 2): 164 | self.layers.append(SS_nbt_module(64, 0.03, 1)) 165 | 166 | self.layers.append(DownsamplerBlock(64,128)) 167 | 168 | for x in range(0, 1): 169 | self.layers.append(SS_nbt_module(128, 0.3, 1)) 170 | self.layers.append(SS_nbt_module(128, 0.3, 2)) 171 | self.layers.append(SS_nbt_module(128, 0.3, 5)) 172 | self.layers.append(SS_nbt_module(128, 0.3, 9)) 173 | 174 | for x in range(0, 1): 175 | self.layers.append(SS_nbt_module(128, 0.3, 2)) 176 | self.layers.append(SS_nbt_module(128, 0.3, 5)) 177 | self.layers.append(SS_nbt_module(128, 0.3, 9)) 178 | self.layers.append(SS_nbt_module(128, 0.3, 17)) 179 | 180 | 181 | def forward(self, input): 182 | output = self.initial_block(input) 183 | 184 | for layer in self.layers: 185 | output = layer(output) 186 | 187 | return output 188 | 189 | 190 | class Features(nn.Module): 191 | def __init__(self): 192 | super().__init__() 193 | self.encoder = Encoder() 194 | self.extralayer1 = nn.MaxPool2d(2, stride=2) 195 | self.extralayer2 = nn.AvgPool2d(14,1,0) 196 | 197 | def forward(self, input): 198 | output = self.encoder(input) 199 | output = self.extralayer1(output) 200 | output = self.extralayer2(output) 201 | return output 202 | 203 | class Classifier(nn.Module): 204 | def __init__(self, num_classes): 205 | super().__init__() 206 | self.linear = nn.Linear(128, num_classes) 207 | 208 | def forward(self, input): 209 | output = input.view(input.size(0), 128) #first is batch_size 210 | output = self.linear(output) 211 | return output 212 | 213 | class 
LEDNet(nn.Module): 214 | def __init__(self, num_classes): #use encoder to pass pretrained encoder 215 | super().__init__() 216 | 217 | self.features = Features() 218 | self.classifier = Classifier(num_classes) 219 | 220 | def forward(self, input): 221 | output = self.features(input) 222 | output = self.classifier(output) 223 | return output 224 | 225 | -------------------------------------------------------------------------------- /imagenet-pretrain/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | import time 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.parallel 9 | import torch.backends.cudnn as cudnn 10 | import torch.optim 11 | import torch.utils.data 12 | import torchvision.transforms as transforms 13 | import torchvision.datasets as datasets 14 | import torchvision.models as models 15 | 16 | from torch.optim import lr_scheduler 17 | 18 | from lednet_imagenet import LEDNet 19 | 20 | model_names = sorted(name for name in models.__dict__ 21 | if name.islower() and not name.startswith("__") 22 | and callable(models.__dict__[name])) 23 | 24 | 25 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 26 | parser.add_argument('data', metavar='DIR', 27 | help='path to dataset') 28 | parser.add_argument('--arch', '-a', metavar='ARCH', default='lednet', 29 | #choices=model_names, 30 | help='model architecture: ' + 31 | ' | '.join(model_names) + 32 | ' (default: resnet18)') 33 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 34 | help='number of data loading workers (default: 4)') 35 | parser.add_argument('--epochs', default=150, type=int, metavar='N', 36 | help='number of total epochs to run') 37 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 38 | help='manual epoch number (useful on restarts)') 39 | parser.add_argument('-b', '--batch-size', default=256, type=int, 40 | metavar='N', help='mini-batch size (default: 256)') 41 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 42 | metavar='LR', help='initial learning rate') 43 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 44 | help='momentum') 45 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, 46 | metavar='W', help='weight decay (default: 1e-4)') 47 | parser.add_argument('--print-freq', '-p', default=10, type=int, 48 | metavar='N', help='print frequency (default: 10)') 49 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 50 | help='path to latest checkpoint (default: none)') 51 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 52 | help='evaluate model on validation set') 53 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 54 | help='use pre-trained model') 55 | 56 | best_prec1 = 0 57 | 58 | def main(): 59 | global args, best_prec1 60 | args = parser.parse_args() 61 | 62 | # create model 63 | if (args.arch == 'lednet'): 64 | model = LEDNet(1000) 65 | else: 66 | if args.pretrained: 67 | print("=> using pre-trained model '{}'".format(args.arch)) 68 | model = models.__dict__[args.arch](pretrained=True) 69 | else: 70 | print("=> creating model '{}'".format(args.arch)) 71 | model = models.__dict__[args.arch]() 72 | 73 | model = torch.nn.DataParallel(model).cuda() 74 | 75 | # define loss function (criterion) and optimizer 76 | criterion = nn.CrossEntropyLoss().cuda() 77 | 78 | optimizer = 
torch.optim.SGD(model.parameters(), args.lr, 79 | momentum=args.momentum, 80 | weight_decay=args.weight_decay) 81 | 82 | # optionally resume from a checkpoint 83 | if args.resume: 84 | if os.path.isfile(args.resume): 85 | print("=> loading checkpoint '{}'".format(args.resume)) 86 | checkpoint = torch.load(args.resume) 87 | args.start_epoch = checkpoint['epoch'] 88 | best_prec1 = checkpoint['best_prec1'] 89 | model.load_state_dict(checkpoint['state_dict']) 90 | optimizer.load_state_dict(checkpoint['optimizer']) 91 | print("=> loaded checkpoint '{}' (epoch {})" 92 | .format(args.resume, checkpoint['epoch'])) 93 | else: 94 | print("=> no checkpoint found at '{}'".format(args.resume)) 95 | 96 | cudnn.benchmark = True 97 | 98 | # Data loading code 99 | traindir = os.path.join(args.data, 'train') 100 | valdir = os.path.join(args.data, 'val') 101 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 102 | std=[0.229, 0.224, 0.225]) 103 | 104 | train_loader = torch.utils.data.DataLoader( 105 | datasets.ImageFolder(traindir, transforms.Compose([ 106 | #RemoveExif(), 107 | transforms.RandomResizedCrop(224), 108 | transforms.RandomHorizontalFlip(), 109 | transforms.ToTensor(), 110 | normalize, 111 | ])), 112 | batch_size=args.batch_size, shuffle=True, 113 | num_workers=args.workers, pin_memory=True) 114 | 115 | val_loader = torch.utils.data.DataLoader( 116 | datasets.ImageFolder(valdir, transforms.Compose([ 117 | transforms.Resize(256), 118 | transforms.CenterCrop(224), 119 | transforms.ToTensor(), 120 | normalize, 121 | ])), 122 | batch_size=args.batch_size, shuffle=False, 123 | num_workers=args.workers, pin_memory=True) 124 | 125 | if args.evaluate: 126 | validate(val_loader, model, criterion) 127 | return 128 | 129 | for epoch in range(args.start_epoch, args.epochs): 130 | adjust_learning_rate(optimizer, epoch) 131 | 132 | # train for one epoch 133 | train(train_loader, model, criterion, optimizer, epoch) 134 | 135 | # evaluate on validation set 136 | prec1 = validate(val_loader, model, criterion) 137 | 138 | # remember best prec@1 and save checkpoint 139 | is_best = prec1 > best_prec1 140 | best_prec1 = max(prec1, best_prec1) 141 | save_checkpoint({ 142 | 'epoch': epoch + 1, 143 | 'arch': args.arch, 144 | 'state_dict': model.state_dict(), 145 | 'best_prec1': best_prec1, 146 | 'optimizer' : optimizer.state_dict(), 147 | }, is_best) 148 | 149 | #scheduler.step(prec1, epoch) #decreases learning rate if prec1 plateaus 150 | 151 | 152 | def train(train_loader, model, criterion, optimizer, epoch): 153 | batch_time = AverageMeter() 154 | data_time = AverageMeter() 155 | losses = AverageMeter() 156 | top1 = AverageMeter() 157 | top5 = AverageMeter() 158 | 159 | # switch to train mode 160 | model.train() 161 | 162 | end = time.time() 163 | for i, (input, target) in enumerate(train_loader): 164 | # measure data loading time 165 | data_time.update(time.time() - end) 166 | 167 | target = target.cuda(async=True) 168 | input_var = torch.autograd.Variable(input) 169 | target_var = torch.autograd.Variable(target) 170 | 171 | # compute output 172 | output = model(input_var) 173 | loss = criterion(output, target_var) 174 | 175 | # measure accuracy and record loss 176 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) 177 | losses.update(loss.data[0], input.size(0)) 178 | top1.update(prec1[0], input.size(0)) 179 | top5.update(prec5[0], input.size(0)) 180 | 181 | # compute gradient and do SGD step 182 | optimizer.zero_grad() 183 | loss.backward() 184 | optimizer.step() 185 | 186 | # measure elapsed 
time 187 | batch_time.update(time.time() - end) 188 | end = time.time() 189 | 190 | if i % args.print_freq == 0: 191 | for param_group in optimizer.param_groups: 192 | lr = param_group['lr'] 193 | print('Epoch: [{0}][{1}/{2}][lr:{lr:.6g}]\t' 194 | 'Time {batch_time.val:.3f} ({batch_time.avg:.2f}) / ' 195 | 'Data {data_time.val:.3f} ({data_time.avg:.2f})\t' 196 | 'Loss {loss.val:.3f} ({loss.avg:.3f})\t' 197 | 'Prec@1 {top1.val:.2f} ({top1.avg:.2f})\t' 198 | 'Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format( 199 | epoch, i, len(train_loader), batch_time=batch_time, 200 | data_time=data_time, loss=losses, top1=top1, top5=top5, lr=lr)) 201 | 202 | 203 | def validate(val_loader, model, criterion): 204 | batch_time = AverageMeter() 205 | losses = AverageMeter() 206 | top1 = AverageMeter() 207 | top5 = AverageMeter() 208 | 209 | # switch to evaluate mode 210 | model.eval() 211 | 212 | end = time.time() 213 | for i, (input, target) in enumerate(val_loader): 214 | target = target.cuda(async=True) 215 | input_var = torch.autograd.Variable(input, volatile=True) 216 | target_var = torch.autograd.Variable(target, volatile=True) 217 | 218 | # compute output 219 | output = model(input_var) 220 | loss = criterion(output, target_var) 221 | 222 | # measure accuracy and record loss 223 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) 224 | losses.update(loss.data[0], input.size(0)) 225 | top1.update(prec1[0], input.size(0)) 226 | top5.update(prec5[0], input.size(0)) 227 | 228 | # measure elapsed time 229 | batch_time.update(time.time() - end) 230 | end = time.time() 231 | 232 | if i % args.print_freq == 0: 233 | print('Test: [{0}/{1}]\t' 234 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 235 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 236 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 237 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 238 | i, len(val_loader), batch_time=batch_time, loss=losses, 239 | top1=top1, top5=top5)) 240 | 241 | print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}' 242 | .format(top1=top1, top5=top5)) 243 | 244 | return top1.avg 245 | 246 | 247 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 248 | torch.save(state, filename) 249 | if is_best: 250 | shutil.copyfile(filename, 'model_best.pth.tar') 251 | 252 | 253 | class AverageMeter(object): 254 | """Computes and stores the average and current value""" 255 | def __init__(self): 256 | self.reset() 257 | 258 | def reset(self): 259 | self.val = 0 260 | self.avg = 0 261 | self.sum = 0 262 | self.count = 0 263 | 264 | def update(self, val, n=1): 265 | self.val = val 266 | self.sum += val * n 267 | self.count += n 268 | self.avg = self.sum / self.count 269 | 270 | 271 | def adjust_learning_rate(optimizer, epoch): 272 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 273 | lr = args.lr 274 | wd = 1e-4 275 | milestone = 15 #after epoch milestone, lr is reduced exponentially 276 | if epoch > milestone: 277 | lr = args.lr * (0.95 ** (epoch-milestone)) 278 | wd = 0 279 | for param_group in optimizer.param_groups: 280 | param_group['lr'] = lr 281 | param_group['weight_decay'] = wd 282 | 283 | def accuracy(output, target, topk=(1,)): 284 | """Computes the precision@k for the specified values of k""" 285 | maxk = max(topk) 286 | batch_size = target.size(0) 287 | 288 | _, pred = output.topk(maxk, 1, True, True) 289 | pred = pred.t() 290 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 291 | 292 | res = [] 293 | for k in topk: 294 | correct_k = 
correct[:k].view(-1).float().sum(0)
295 | res.append(correct_k.mul_(100.0 / batch_size))
296 | return res
297 | 
298 | 
299 | if __name__ == '__main__':
300 | main()
301 | 
-------------------------------------------------------------------------------- /images/LEDNet_demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiaoyufenfei/LEDNet/5d900d9cfabb3091c952b79be34246aea0608e42/images/LEDNet_demo.png
-------------------------------------------------------------------------------- /images/LEDNet_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiaoyufenfei/LEDNet/5d900d9cfabb3091c952b79be34246aea0608e42/images/LEDNet_overview.png
-------------------------------------------------------------------------------- /requirements.txt: --------------------------------------------------------------------------------
1 | torch==0.4.1
2 | torchvision==0.2.1
3 | torchsummary==1.5.1
4 | visdom==0.1.8.4
5 | 
-------------------------------------------------------------------------------- /save/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiaoyufenfei/LEDNet/5d900d9cfabb3091c952b79be34246aea0608e42/save/.gitkeep
-------------------------------------------------------------------------------- /test/README.md: --------------------------------------------------------------------------------
1 | #### Functions for evaluating/visualizing the network's output
2 | 
3 | Currently there are four usable evaluation scripts:
4 | - eval_cityscapes_color
5 | - eval_cityscapes_server
6 | - eval_iou
7 | - eval_forward_time
8 | 
9 | #### eval_cityscapes_server.py
10 | 
11 | This code can be used to produce segmentations of the Cityscapes images and convert the output indices to the original 'labelIds', so the results can be evaluated using the scripts from the Cityscapes dataset (evalPixelLevelSemanticLabeling.py) or uploaded to the Cityscapes test server. By default it saves images in the test/save_results/ folder.
12 | 
13 | **Options:** Specify the Cityscapes folder path with '--datadir' option. Select the cityscapes subset with '--subset' ('val', 'test', 'train' or 'demoSequence'). For other options check the bottom side of the file.
14 | 
15 | **Examples:**
16 | ```
17 | python eval_cityscapes_server.py --datadir /xx/datasets/cityscapes/ --loadDir ../save/logs/ --loadWeights model_best.pth --loadModel lednet.py --subset val
18 | ```
19 | 
20 | #### eval_cityscapes_color.py
21 | 
22 | This code can be used to produce color segmentations of the Cityscapes images for visualization purposes. By default it saves images in the test/save_color/ folder. You can also visualize the results in Visdom with the --visualize flag.
23 | 
24 | **Options:** Specify the Cityscapes folder path with '--datadir' option. Select the cityscapes subset with '--subset' ('val', 'test', 'train' or 'demoSequence'). For other options check the bottom side of the file.
25 | 
26 | **Examples:**
27 | 
28 | ```
29 | python eval_cityscapes_color.py --datadir /xx/datasets/cityscapes/ --loadDir ../save/logs/ --loadWeights model_best.pth --loadModel lednet.py --subset val
30 | ```
31 | 
32 | #### eval_iou.py
33 | 
34 | This code can be used to calculate the IoU (mean and per-class) on a subset of images with labels available, such as the Cityscapes val/train sets.
35 | 
36 | **Options:** Specify the Cityscapes folder path with '--datadir' option.
Select the cityscapes subset with '--subset' ('val' or 'train'). For other options check the bottom side of the file. 37 | 38 | **Examples:** 39 | 40 | ``` 41 | python eval_iou.py --datadir /xx/datasets/cityscapes/ --loadDir ../save/logs/ --loadWeights model_best.pth --loadModel lednet.py --subset val 42 | ``` 43 | 44 | #### eval_forward_time.py 45 | This function loads a model specified by '-m' and enters a loop to continuously estimate forward pass time (fwt) in the specified resolution. 46 | 47 | **Options:** Option '--width' specifies the width (default: 1024). Option '--height' specifies the height (default: 512). For other options check the bottom side of the file. 48 | 49 | **Examples:** 50 | ``` 51 | python eval_forward_time.py --batch-size=6 52 | ``` 53 | 54 | 55 | 56 | -------------------------------------------------------------------------------- /test/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | 4 | from PIL import Image 5 | 6 | from torch.utils.data import Dataset 7 | 8 | EXTENSIONS = ['.jpg', '.png'] 9 | 10 | def load_image(file): 11 | return Image.open(file) 12 | 13 | def is_image(filename): 14 | return any(filename.endswith(ext) for ext in EXTENSIONS) 15 | 16 | def is_label(filename): 17 | return filename.endswith("_labelTrainIds.png") 18 | 19 | def image_path(root, basename, extension): 20 | return os.path.join(root, f'{basename}{extension}') 21 | 22 | def image_path_city(root, name): 23 | return os.path.join(root, f'{name}') 24 | 25 | def image_basename(filename): 26 | return os.path.basename(os.path.splitext(filename)[0]) 27 | 28 | class VOC12(Dataset): 29 | 30 | def __init__(self, root, input_transform=None, target_transform=None): 31 | self.images_root = os.path.join(root, 'images') 32 | self.labels_root = os.path.join(root, 'labels') 33 | 34 | self.filenames = [image_basename(f) 35 | for f in os.listdir(self.labels_root) if is_image(f)] 36 | self.filenames.sort() 37 | 38 | self.input_transform = input_transform 39 | self.target_transform = target_transform 40 | 41 | def __getitem__(self, index): 42 | filename = self.filenames[index] 43 | 44 | with open(image_path(self.images_root, filename, '.jpg'), 'rb') as f: 45 | image = load_image(f).convert('RGB') 46 | with open(image_path(self.labels_root, filename, '.png'), 'rb') as f: 47 | label = load_image(f).convert('P') 48 | 49 | if self.input_transform is not None: 50 | image = self.input_transform(image) 51 | if self.target_transform is not None: 52 | label = self.target_transform(label) 53 | 54 | return image, label 55 | 56 | def __len__(self): 57 | return len(self.filenames) 58 | 59 | 60 | class cityscapes(Dataset): 61 | 62 | def __init__(self, root, input_transform=None, target_transform=None, subset='val'): 63 | self.images_root = os.path.join(root, 'leftImg8bit/' + subset) 64 | self.labels_root = os.path.join(root, 'gtFine/' + subset) 65 | 66 | self.filenames = [os.path.join(dp, f) for dp, dn, fn in os.walk(os.path.expanduser(self.images_root)) for f in fn if is_image(f)] 67 | self.filenames.sort() 68 | 69 | self.filenamesGt = [os.path.join(dp, f) for dp, dn, fn in os.walk(os.path.expanduser(self.labels_root)) for f in fn if is_label(f)] 70 | self.filenamesGt.sort() 71 | 72 | self.input_transform = input_transform 73 | self.target_transform = target_transform 74 | 75 | def __getitem__(self, index): 76 | filename = self.filenames[index] 77 | filenameGt = self.filenamesGt[index] 78 | 79 | #print(filename) 80 | 81 | with 
open(image_path_city(self.images_root, filename), 'rb') as f: 82 | image = load_image(f).convert('RGB') 83 | with open(image_path_city(self.labels_root, filenameGt), 'rb') as f: 84 | label = load_image(f).convert('P') 85 | 86 | if self.input_transform is not None: 87 | image = self.input_transform(image) 88 | if self.target_transform is not None: 89 | label = self.target_transform(label) 90 | 91 | return image, label, filename, filenameGt 92 | 93 | def __len__(self): 94 | return len(self.filenames) 95 | 96 | -------------------------------------------------------------------------------- /test/eval_cityscapes_color.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import os 4 | import importlib 5 | 6 | from PIL import Image 7 | from argparse import ArgumentParser 8 | 9 | from torch.autograd import Variable 10 | from torch.utils.data import DataLoader 11 | from torchvision.transforms import Compose, CenterCrop, Normalize, Resize 12 | from torchvision.transforms import ToTensor, ToPILImage 13 | 14 | from dataset import cityscapes 15 | 16 | 17 | 18 | from lednet import Net 19 | 20 | 21 | 22 | from transform import Relabel, ToLabel, Colorize 23 | 24 | import visdom 25 | 26 | 27 | NUM_CHANNELS = 3 28 | NUM_CLASSES = 20 29 | 30 | image_transform = ToPILImage() 31 | input_transform_cityscapes = Compose([ 32 | Resize((512,1024),Image.BILINEAR), 33 | ToTensor(), 34 | #Normalize([.485, .456, .406], [.229, .224, .225]), 35 | ]) 36 | target_transform_cityscapes = Compose([ 37 | Resize((512,1024),Image.NEAREST), 38 | ToLabel(), 39 | Relabel(255, 19), #ignore label to 19 40 | ]) 41 | 42 | cityscapes_trainIds2labelIds = Compose([ 43 | Relabel(19, 255), 44 | Relabel(18, 33), 45 | Relabel(17, 32), 46 | Relabel(16, 31), 47 | Relabel(15, 28), 48 | Relabel(14, 27), 49 | Relabel(13, 26), 50 | Relabel(12, 25), 51 | Relabel(11, 24), 52 | Relabel(10, 23), 53 | Relabel(9, 22), 54 | Relabel(8, 21), 55 | Relabel(7, 20), 56 | Relabel(6, 19), 57 | Relabel(5, 17), 58 | Relabel(4, 13), 59 | Relabel(3, 12), 60 | Relabel(2, 11), 61 | Relabel(1, 8), 62 | Relabel(0, 7), 63 | Relabel(255, 0), 64 | ToPILImage(), 65 | ]) 66 | 67 | def main(args): 68 | 69 | modelpath = args.loadDir + args.loadModel 70 | weightspath = args.loadDir + args.loadWeights 71 | 72 | print ("Loading model: " + modelpath) 73 | print ("Loading weights: " + weightspath) 74 | 75 | model = Net(NUM_CLASSES) 76 | 77 | model = torch.nn.DataParallel(model) 78 | if (not args.cpu): 79 | model = model.cuda() 80 | 81 | #model.load_state_dict(torch.load(args.state)) 82 | #model.load_state_dict(torch.load(weightspath)) #not working if missing key 83 | 84 | def load_my_state_dict(model, state_dict): #custom function to load model when not all dict elements 85 | own_state = model.state_dict() 86 | for name, param in state_dict.items(): 87 | if name not in own_state: 88 | continue 89 | own_state[name].copy_(param) 90 | return model 91 | 92 | model = load_my_state_dict(model, torch.load(weightspath)) 93 | print ("Model and weights LOADED successfully") 94 | 95 | model.eval() 96 | 97 | if(not os.path.exists(args.datadir)): 98 | print ("Error: datadir could not be loaded") 99 | 100 | 101 | loader = DataLoader(cityscapes(args.datadir, input_transform_cityscapes, target_transform_cityscapes, subset=args.subset), 102 | num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False) 103 | 104 | # For visualizer: 105 | # must launch in other window "python3.6 -m visdom.server -port 8097" 106 
| # and access localhost:8097 to see it 107 | if (args.visualize): 108 | vis = visdom.Visdom() 109 | 110 | for step, (images, labels, filename, filenameGt) in enumerate(loader): 111 | if (not args.cpu): 112 | images = images.cuda() 113 | #labels = labels.cuda() 114 | 115 | inputs = Variable(images) 116 | #targets = Variable(labels) 117 | with torch.no_grad(): 118 | outputs = model(inputs) 119 | 120 | label = outputs[0].max(0)[1].byte().cpu().data 121 | #label_cityscapes = cityscapes_trainIds2labelIds(label.unsqueeze(0)) 122 | label_color = Colorize()(label.unsqueeze(0)) 123 | 124 | filenameSave = "./save_color/" + filename[0].split("leftImg8bit/")[1] 125 | os.makedirs(os.path.dirname(filenameSave), exist_ok=True) 126 | #image_transform(label.byte()).save(filenameSave) 127 | label_save = ToPILImage()(label_color) 128 | label_save.save(filenameSave) 129 | 130 | if (args.visualize): 131 | vis.image(label_color.numpy()) 132 | print (step, filenameSave) 133 | 134 | 135 | 136 | if __name__ == '__main__': 137 | parser = ArgumentParser() 138 | 139 | parser.add_argument('--state') 140 | 141 | 142 | parser.add_argument('--loadDir',default="../save/logs/") 143 | parser.add_argument('--loadWeights', default="model_best.pth") 144 | parser.add_argument('--loadModel', default="lednet.py") 145 | parser.add_argument('--subset', default="val") #can be val, test, train, demoSequence 146 | 147 | parser.add_argument('--datadir', default=os.getenv("HOME") + "/datasets/cityscapes/") 148 | parser.add_argument('--num-workers', type=int, default=4) 149 | parser.add_argument('--batch-size', type=int, default=1) 150 | parser.add_argument('--cpu', action='store_true') 151 | 152 | parser.add_argument('--visualize', action='store_true') 153 | main(parser.parse_args()) 154 | -------------------------------------------------------------------------------- /test/eval_cityscapes_server.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import os 4 | import importlib 5 | 6 | from PIL import Image 7 | from argparse import ArgumentParser 8 | 9 | from torch.autograd import Variable 10 | from torch.utils.data import DataLoader 11 | from torchvision.transforms import Compose, CenterCrop, Normalize, Resize 12 | from torchvision.transforms import ToTensor, ToPILImage 13 | 14 | from dataset import cityscapes 15 | 16 | from lednet import Net 17 | 18 | 19 | from transform import Relabel, ToLabel, Colorize 20 | 21 | 22 | NUM_CHANNELS = 3 23 | NUM_CLASSES = 20 24 | 25 | image_transform = ToPILImage() 26 | input_transform_cityscapes = Compose([ 27 | Resize(512), 28 | ToTensor(), 29 | #Normalize([.485, .456, .406], [.229, .224, .225]), 30 | ]) 31 | target_transform_cityscapes = Compose([ 32 | Resize(512), 33 | ToLabel(), 34 | Relabel(255, 19), #ignore label to 19 35 | ]) 36 | 37 | cityscapes_trainIds2labelIds = Compose([ 38 | Relabel(19, 255), 39 | Relabel(18, 33), 40 | Relabel(17, 32), 41 | Relabel(16, 31), 42 | Relabel(15, 28), 43 | Relabel(14, 27), 44 | Relabel(13, 26), 45 | Relabel(12, 25), 46 | Relabel(11, 24), 47 | Relabel(10, 23), 48 | Relabel(9, 22), 49 | Relabel(8, 21), 50 | Relabel(7, 20), 51 | Relabel(6, 19), 52 | Relabel(5, 17), 53 | Relabel(4, 13), 54 | Relabel(3, 12), 55 | Relabel(2, 11), 56 | Relabel(1, 8), 57 | Relabel(0, 7), 58 | Relabel(255, 0), 59 | ToPILImage(), 60 | Resize(1024, Image.NEAREST), 61 | ]) 62 | 63 | def main(args): 64 | 65 | modelpath = args.loadDir + args.loadModel 66 | weightspath = args.loadDir + args.loadWeights 67 | 68 | 
print ("Loading model: " + modelpath) 69 | print ("Loading weights: " + weightspath) 70 | 71 | 72 | model = Net(NUM_CLASSES) 73 | 74 | model = torch.nn.DataParallel(model) 75 | if (not args.cpu): 76 | model = model.cuda() 77 | 78 | #model.load_state_dict(torch.load(args.state)) 79 | #model.load_state_dict(torch.load(weightspath)) #not working if missing key 80 | 81 | def load_my_state_dict(model, state_dict): #custom function to load model when not all dict elements 82 | own_state = model.state_dict() 83 | for name, param in state_dict.items(): 84 | if name not in own_state: 85 | continue 86 | own_state[name].copy_(param) 87 | return model 88 | 89 | model = load_my_state_dict(model, torch.load(weightspath)) 90 | print ("Model and weights LOADED successfully") 91 | 92 | model.eval() 93 | 94 | if(not os.path.exists(args.datadir)): 95 | print ("Error: datadir could not be loaded") 96 | 97 | 98 | loader = DataLoader(cityscapes(args.datadir, input_transform_cityscapes, target_transform_cityscapes, subset=args.subset), 99 | num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False) 100 | 101 | for step, (images, labels, filename, filenameGt) in enumerate(loader): 102 | if (not args.cpu): 103 | images = images.cuda() 104 | #labels = labels.cuda() 105 | 106 | inputs = Variable(images) 107 | #targets = Variable(labels) 108 | with torch.no_grad(): 109 | outputs = model(inputs) 110 | 111 | label = outputs[0].max(0)[1].byte().cpu().data 112 | label_cityscapes = cityscapes_trainIds2labelIds(label.unsqueeze(0)) 113 | #print (numpy.unique(label.numpy())) #debug 114 | 115 | 116 | filenameSave = "./save_results/" + filename[0].split("leftImg8bit/")[1] 117 | 118 | os.makedirs(os.path.dirname(filenameSave), exist_ok=True) 119 | #image_transform(label.byte()).save(filenameSave) 120 | label_cityscapes.save(filenameSave) 121 | 122 | print (step, filenameSave) 123 | 124 | 125 | 126 | if __name__ == '__main__': 127 | parser = ArgumentParser() 128 | 129 | parser.add_argument('--state') 130 | 131 | parser.add_argument('--loadDir',default="../save/logs/") 132 | parser.add_argument('--loadWeights', default="model_best.pth") 133 | parser.add_argument('--loadModel', default="lednet.py") 134 | parser.add_argument('--subset', default="val") #can be val, test, train, demoSequence 135 | parser.add_argument('--datadir', default=os.getenv("HOME") + "/datasets/cityscapes/") 136 | parser.add_argument('--num-workers', type=int, default=4) 137 | parser.add_argument('--batch-size', type=int, default=1) 138 | parser.add_argument('--cpu', action='store_true') 139 | 140 | main(parser.parse_args()) 141 | -------------------------------------------------------------------------------- /test/eval_forward_time.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import time 5 | 6 | from PIL import Image 7 | from argparse import ArgumentParser 8 | 9 | from torch.autograd import Variable 10 | 11 | 12 | from lednet_no_bn import Net 13 | 14 | from transform import Relabel, ToLabel, Colorize 15 | 16 | import torch.backends.cudnn as cudnn 17 | cudnn.benchmark = True 18 | 19 | def main(args): 20 | model = Net(19) 21 | 22 | if (not args.cpu): 23 | model = model.cuda()#.half() #HALF seems to be doing slower for some reason 24 | #model = torch.nn.DataParallel(model).cuda() 25 | 26 | model.eval() 27 | 28 | 29 | images = torch.randn(args.batch_size, args.num_channels, args.height, args.width) 30 | 31 | if (not args.cpu): 32 | images = 
images.cuda()#.half() 33 | 34 | time_train = [] 35 | 36 | i=0 37 | 38 | while(True): 39 | #for step, (images, labels, filename, filenameGt) in enumerate(loader): 40 | 41 | start_time = time.time() 42 | 43 | inputs = Variable(images) 44 | with torch.no_grad(): 45 | outputs = model(inputs) 46 | 47 | #preds = outputs.cpu() 48 | if (not args.cpu): 49 | torch.cuda.synchronize() #wait for cuda to finish (cuda is asynchronous!) 50 | 51 | if i!=0: #first run always takes some time for setup 52 | fwt = time.time() - start_time 53 | time_train.append(fwt) 54 | print ("Forward time per img (b=%d): %.3f (Mean: %.3f)" % (args.batch_size, fwt/args.batch_size, sum(time_train) / len(time_train) / args.batch_size)) 55 | 56 | time.sleep(1) #to avoid overheating the GPU too much 57 | i+=1 58 | 59 | if __name__ == '__main__': 60 | parser = ArgumentParser() 61 | 62 | parser.add_argument('--width', type=int, default=1024) 63 | parser.add_argument('--height', type=int, default=512) 64 | parser.add_argument('--num-channels', type=int, default=3) 65 | parser.add_argument('--batch-size', type=int, default=1) 66 | parser.add_argument('--cpu', action='store_true') 67 | 68 | main(parser.parse_args()) 69 | -------------------------------------------------------------------------------- /test/eval_iou.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | import os 5 | import importlib 6 | import time 7 | 8 | from PIL import Image 9 | from argparse import ArgumentParser 10 | 11 | from torch.autograd import Variable 12 | from torch.utils.data import DataLoader 13 | from torchvision.transforms import Compose, CenterCrop, Normalize, Resize 14 | from torchvision.transforms import ToTensor, ToPILImage 15 | 16 | from dataset import cityscapes 17 | 18 | from lednet.py import Net 19 | 20 | 21 | from transform import Relabel, ToLabel, Colorize 22 | from iouEval import iouEval, getColorEntry 23 | 24 | NUM_CHANNELS = 3 25 | NUM_CLASSES = 20 26 | 27 | image_transform = ToPILImage() 28 | input_transform_cityscapes = Compose([ 29 | Resize(512, Image.BILINEAR), 30 | ToTensor(), 31 | ]) 32 | target_transform_cityscapes = Compose([ 33 | Resize(512, Image.NEAREST), 34 | ToLabel(), 35 | Relabel(255, 19), #ignore label to 19 36 | ]) 37 | 38 | def main(args): 39 | 40 | modelpath = args.loadDir + args.loadModel 41 | weightspath = args.loadDir + args.loadWeights 42 | 43 | print ("Loading model: " + modelpath) 44 | print ("Loading weights: " + weightspath) 45 | 46 | model = Net(NUM_CLASSES) 47 | 48 | #model = torch.nn.DataParallel(model) 49 | if (not args.cpu): 50 | model = torch.nn.DataParallel(model).cuda() 51 | 52 | def load_my_state_dict(model, state_dict): #custom function to load model when not all dict elements 53 | own_state = model.state_dict() 54 | for name, param in state_dict.items(): 55 | if name not in own_state: 56 | if name.startswith("module."): 57 | own_state[name.split("module.")[-1]].copy_(param) 58 | else: 59 | print(name, " not loaded") 60 | continue 61 | else: 62 | own_state[name].copy_(param) 63 | return model 64 | 65 | model = load_my_state_dict(model, torch.load(weightspath, map_location=lambda storage, loc: storage)) 66 | print ("Model and weights LOADED successfully") 67 | 68 | 69 | model.eval() 70 | 71 | if(not os.path.exists(args.datadir)): 72 | print ("Error: datadir could not be loaded") 73 | 74 | 75 | loader = DataLoader(cityscapes(args.datadir, input_transform_cityscapes, target_transform_cityscapes, 
subset=args.subset), num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False) 76 | 77 | 78 | iouEvalVal = iouEval(NUM_CLASSES) 79 | 80 | start = time.time() 81 | 82 | for step, (images, labels, filename, filenameGt) in enumerate(loader): 83 | if (not args.cpu): 84 | images = images.cuda() 85 | labels = labels.cuda() 86 | 87 | inputs = Variable(images) 88 | with torch.no_grad(): 89 | outputs = model(inputs) 90 | 91 | iouEvalVal.addBatch(outputs.max(1)[1].unsqueeze(1).data, labels) 92 | 93 | filenameSave = filename[0].split("leftImg8bit/")[1] 94 | 95 | print (step, filenameSave) 96 | 97 | 98 | iouVal, iou_classes = iouEvalVal.getIoU() 99 | 100 | iou_classes_str = [] 101 | for i in range(iou_classes.size(0)): 102 | iouStr = getColorEntry(iou_classes[i])+'{:0.2f}'.format(iou_classes[i]*100) + '\033[0m' 103 | iou_classes_str.append(iouStr) 104 | 105 | print("---------------------------------------") 106 | print("Took ", time.time()-start, "seconds") 107 | print("=======================================") 108 | #print("TOTAL IOU: ", iou * 100, "%") 109 | print("Per-Class IoU:") 110 | print(iou_classes_str[0], "Road") 111 | print(iou_classes_str[1], "sidewalk") 112 | print(iou_classes_str[2], "building") 113 | print(iou_classes_str[3], "wall") 114 | print(iou_classes_str[4], "fence") 115 | print(iou_classes_str[5], "pole") 116 | print(iou_classes_str[6], "traffic light") 117 | print(iou_classes_str[7], "traffic sign") 118 | print(iou_classes_str[8], "vegetation") 119 | print(iou_classes_str[9], "terrain") 120 | print(iou_classes_str[10], "sky") 121 | print(iou_classes_str[11], "person") 122 | print(iou_classes_str[12], "rider") 123 | print(iou_classes_str[13], "car") 124 | print(iou_classes_str[14], "truck") 125 | print(iou_classes_str[15], "bus") 126 | print(iou_classes_str[16], "train") 127 | print(iou_classes_str[17], "motorcycle") 128 | print(iou_classes_str[18], "bicycle") 129 | print("=======================================") 130 | iouStr = getColorEntry(iouVal)+'{:0.2f}'.format(iouVal*100) + '\033[0m' 131 | print ("MEAN IoU: ", iouStr, "%") 132 | 133 | if __name__ == '__main__': 134 | parser = ArgumentParser() 135 | 136 | parser.add_argument('--state') 137 | 138 | parser.add_argument('--loadDir',default="../trained_models/") 139 | parser.add_argument('--loadWeights', default="lednet_trained.pth") 140 | parser.add_argument('--loadModel', default="lednet.py") 141 | parser.add_argument('--subset', default="val") #can be val or train (must have labels) 142 | parser.add_argument('--datadir', default=os.getenv("HOME") + "/datasets/cityscapes/") 143 | parser.add_argument('--num-workers', type=int, default=4) 144 | parser.add_argument('--batch-size', type=int, default=1) 145 | parser.add_argument('--cpu', action='store_true') 146 | 147 | main(parser.parse_args()) 148 | -------------------------------------------------------------------------------- /test/iouEval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class iouEval: 4 | 5 | def __init__(self, nClasses, ignoreIndex=19): 6 | self.nClasses = nClasses 7 | self.ignoreIndex = ignoreIndex if nClasses>ignoreIndex else -1 #if ignoreIndex is larger than nClasses, consider no ignoreIndex 8 | self.reset() 9 | 10 | def reset (self): 11 | classes = self.nClasses if self.ignoreIndex==-1 else self.nClasses-1 12 | self.tp = torch.zeros(classes).double() 13 | self.fp = torch.zeros(classes).double() 14 | self.fn = torch.zeros(classes).double() 15 | 16 | def addBatch(self, x, y): 
#x=preds, y=targets 17 | #sizes should be "batch_size x nClasses x H x W" 18 | 19 | #print ("X is cuda: ", x.is_cuda) 20 | #print ("Y is cuda: ", y.is_cuda) 21 | 22 | if (x.is_cuda or y.is_cuda): 23 | x = x.cuda() 24 | y = y.cuda() 25 | 26 | #if size is "batch_size x 1 x H x W" scatter to onehot 27 | if (x.size(1) == 1): 28 | x_onehot = torch.zeros(x.size(0), self.nClasses, x.size(2), x.size(3)) 29 | if x.is_cuda: 30 | x_onehot = x_onehot.cuda() 31 | x_onehot.scatter_(1, x, 1).float() 32 | else: 33 | x_onehot = x.float() 34 | 35 | if (y.size(1) == 1): 36 | y_onehot = torch.zeros(y.size(0), self.nClasses, y.size(2), y.size(3)) 37 | if y.is_cuda: 38 | y_onehot = y_onehot.cuda() 39 | y_onehot.scatter_(1, y, 1).float() 40 | else: 41 | y_onehot = y.float() 42 | 43 | if (self.ignoreIndex != -1): 44 | ignores = y_onehot[:,self.ignoreIndex].unsqueeze(1) 45 | x_onehot = x_onehot[:, :self.ignoreIndex] 46 | y_onehot = y_onehot[:, :self.ignoreIndex] 47 | else: 48 | ignores=0 49 | 50 | #print(type(x_onehot)) 51 | #print(type(y_onehot)) 52 | #print(x_onehot.size()) 53 | #print(y_onehot.size()) 54 | 55 | tpmult = x_onehot * y_onehot #times prediction and gt coincide is 1 56 | tp = torch.sum(torch.sum(torch.sum(tpmult, dim=0, keepdim=True), dim=2, keepdim=True), dim=3, keepdim=True).squeeze() 57 | fpmult = x_onehot * (1-y_onehot-ignores) #times prediction says its that class and gt says its not (subtracting cases when its ignore label!) 58 | fp = torch.sum(torch.sum(torch.sum(fpmult, dim=0, keepdim=True), dim=2, keepdim=True), dim=3, keepdim=True).squeeze() 59 | fnmult = (1-x_onehot) * (y_onehot) #times prediction says its not that class and gt says it is 60 | fn = torch.sum(torch.sum(torch.sum(fnmult, dim=0, keepdim=True), dim=2, keepdim=True), dim=3, keepdim=True).squeeze() 61 | 62 | self.tp += tp.double().cpu() 63 | self.fp += fp.double().cpu() 64 | self.fn += fn.double().cpu() 65 | 66 | def getIoU(self): 67 | num = self.tp 68 | den = self.tp + self.fp + self.fn + 1e-15 69 | iou = num / den 70 | return torch.mean(iou), iou #returns "iou mean", "iou per class" 71 | 72 | # Class for colors 73 | class colors: 74 | RED = '\033[31;1m' 75 | GREEN = '\033[32;1m' 76 | YELLOW = '\033[33;1m' 77 | BLUE = '\033[34;1m' 78 | MAGENTA = '\033[35;1m' 79 | CYAN = '\033[36;1m' 80 | BOLD = '\033[1m' 81 | UNDERLINE = '\033[4m' 82 | ENDC = '\033[0m' 83 | 84 | # Colored value output if colorized flag is activated. 
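# Added descriptive note: getColorEntry maps an IoU value in [0, 1] to an ANSI color escape, per the thresholds below: red < 0.20, yellow < 0.40, blue < 0.60, cyan < 0.80, green otherwise (non-float input resets the color).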
85 | def getColorEntry(val): 86 | if not isinstance(val, float): 87 | return colors.ENDC 88 | if (val < .20): 89 | return colors.RED 90 | elif (val < .40): 91 | return colors.YELLOW 92 | elif (val < .60): 93 | return colors.BLUE 94 | elif (val < .80): 95 | return colors.CYAN 96 | else: 97 | return colors.GREEN 98 | 99 | -------------------------------------------------------------------------------- /test/lednet_no_bn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import torch.nn.functional as F 5 | from torch.nn.functional import interpolate as interpolate 6 | 7 | def split(x): 8 | c = int(x.size()[1]) 9 | c1 = round(c * 0.5) 10 | x1 = x[:, :c1, :, :].contiguous() 11 | x2 = x[:, c1:, :, :].contiguous() 12 | 13 | return x1, x2 14 | 15 | def channel_shuffle(x,groups): 16 | batchsize, num_channels, height, width = x.data.size() 17 | 18 | channels_per_group = num_channels // groups 19 | 20 | #reshape 21 | x = x.view(batchsize,groups, 22 | channels_per_group,height,width) 23 | 24 | x = torch.transpose(x,1,2).contiguous() 25 | 26 | #flatten 27 | x = x.view(batchsize,-1,height,width) 28 | 29 | return x 30 | 31 | 32 | class Conv2dBnRelu(nn.Module): 33 | def __init__(self,in_ch,out_ch,kernel_size=3,stride=1,padding=0,dilation=1,bias=True): 34 | super(Conv2dBnRelu,self).__init__() 35 | 36 | self.conv = nn.Sequential( 37 | nn.Conv2d(in_ch,out_ch,kernel_size,stride,padding,dilation=dilation,bias=bias), 38 | # nn.BatchNorm2d(out_ch, eps=1e-3), 39 | nn.ReLU(inplace=True) 40 | ) 41 | 42 | def forward(self, x): 43 | return self.conv(x) 44 | 45 | 46 | 47 | class DownsamplerBlock (nn.Module): 48 | def __init__(self, in_channel, out_channel): 49 | super(DownsamplerBlock,self).__init__() 50 | 51 | self.conv = nn.Conv2d(in_channel, out_channel-in_channel, (3, 3), stride=2, padding=1, bias=True) 52 | self.pool = nn.MaxPool2d(2, stride=2) 53 | self.bn = nn.BatchNorm2d(out_channel, eps=1e-3) 54 | self.relu = nn.ReLU(inplace=True) 55 | 56 | def forward(self, input): 57 | x1 = self.pool(input) 58 | x2 = self.conv(input) 59 | 60 | diffY = x2.size()[2] - x1.size()[2] 61 | diffX = x2.size()[3] - x1.size()[3] 62 | 63 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, 64 | diffY // 2, diffY - diffY // 2]) 65 | 66 | output = torch.cat([x2, x1], 1) 67 | output = self.bn(output) 68 | output = self.relu(output) 69 | return output 70 | 71 | 72 | class SS_nbt_module(nn.Module): 73 | def __init__(self, chann, dropprob, dilated): 74 | super(SS_nbt_module,self).__init__() 75 | 76 | oup_inc = chann//2 77 | 78 | # 79 | self.conv3x1_1_l = nn.Conv2d(oup_inc, oup_inc, (3,1), stride=1, padding=(1,0), bias=True) 80 | 81 | self.conv1x3_1_l = nn.Conv2d(oup_inc, oup_inc, (1,3), stride=1, padding=(0,1), bias=True) 82 | 83 | self.bn1_l = nn.BatchNorm2d(oup_inc, eps=1e-03) 84 | 85 | self.conv3x1_2_l = nn.Conv2d(oup_inc, oup_inc, (3,1), stride=1, padding=(1*dilated,0), bias=True, dilation = (dilated,1)) 86 | 87 | self.conv1x3_2_l = nn.Conv2d(oup_inc, oup_inc, (1,3), stride=1, padding=(0,1*dilated), bias=True, dilation = (1,dilated)) 88 | 89 | self.bn2_l = nn.BatchNorm2d(oup_inc, eps=1e-03) 90 | 91 | # 92 | self.conv3x1_1_r = nn.Conv2d(oup_inc, oup_inc, (3,1), stride=1, padding=(1,0), bias=True) 93 | 94 | self.conv1x3_1_r = nn.Conv2d(oup_inc, oup_inc, (1,3), stride=1, padding=(0,1), bias=True) 95 | 96 | self.bn1_r = nn.BatchNorm2d(oup_inc, eps=1e-03) 97 | 98 | self.conv3x1_2_r = nn.Conv2d(oup_inc, oup_inc, (3,1), stride=1, 
padding=(1*dilated,0), bias=True, dilation = (dilated,1)) 99 | 100 | self.conv1x3_2_r = nn.Conv2d(oup_inc, oup_inc, (1,3), stride=1, padding=(0,1*dilated), bias=True, dilation = (1,dilated)) 101 | 102 | self.bn2_r = nn.BatchNorm2d(oup_inc, eps=1e-03) 103 | 104 | self.relu = nn.ReLU(inplace=True) 105 | self.dropout = nn.Dropout2d(dropprob) 106 | 107 | @staticmethod 108 | def _concat(x,out): 109 | return torch.cat((x,out),1) 110 | 111 | def forward(self, input): 112 | 113 | # x1 = input[:,:(input.shape[1]//2),:,:] 114 | # x2 = input[:,(input.shape[1]//2):,:,:] 115 | residual = input 116 | x1, x2 = split(input) 117 | 118 | output1 = self.conv3x1_1_l(x1) 119 | output1 = self.relu(output1) 120 | output1 = self.conv1x3_1_l(output1) 121 | # output1 = self.bn1_l(output1) 122 | output1 = self.relu(output1) 123 | 124 | output1 = self.conv3x1_2_l(output1) 125 | output1 = self.relu(output1) 126 | output1 = self.conv1x3_2_l(output1) 127 | # output1 = self.bn2_l(output1) 128 | 129 | 130 | output2 = self.conv1x3_1_r(x2) 131 | output2 = self.relu(output2) 132 | output2 = self.conv3x1_1_r(output2) 133 | # output2 = self.bn1_r(output2) 134 | output2 = self.relu(output2) 135 | 136 | output2 = self.conv1x3_2_r(output2) 137 | output2 = self.relu(output2) 138 | output2 = self.conv3x1_2_r(output2) 139 | # output2 = self.bn2_r(output2) 140 | 141 | # if (self.dropout.p != 0): 142 | #output1 = self.dropout(output1) 143 | #output2 = self.dropout(output2) 144 | 145 | out = self._concat(output1,output2) 146 | out = F.relu(residual + out, inplace=True) 147 | return channel_shuffle(out, 2) 148 | 149 | class Encoder(nn.Module): 150 | def __init__(self, num_classes): 151 | super().__init__() 152 | self.initial_block = DownsamplerBlock(3,32) 153 | 154 | 155 | self.layers = nn.ModuleList() 156 | 157 | for x in range(0, 3): 158 | self.layers.append(SS_nbt_module(32, 0.03, 1)) 159 | 160 | 161 | self.layers.append(DownsamplerBlock(32,64)) 162 | 163 | 164 | for x in range(0, 2): 165 | self.layers.append(SS_nbt_module(64, 0.03, 1)) 166 | 167 | self.layers.append(DownsamplerBlock(64,128)) 168 | 169 | for x in range(0, 1): 170 | self.layers.append(SS_nbt_module(128, 0.3, 1)) 171 | self.layers.append(SS_nbt_module(128, 0.3, 2)) 172 | self.layers.append(SS_nbt_module(128, 0.3, 5)) 173 | self.layers.append(SS_nbt_module(128, 0.3, 9)) 174 | 175 | for x in range(0, 1): 176 | self.layers.append(SS_nbt_module(128, 0.3, 2)) 177 | self.layers.append(SS_nbt_module(128, 0.3, 5)) 178 | self.layers.append(SS_nbt_module(128, 0.3, 9)) 179 | self.layers.append(SS_nbt_module(128, 0.3, 17)) 180 | 181 | 182 | # Only in encoder mode: 183 | self.output_conv = nn.Conv2d(128, num_classes, 1, stride=1, padding=0, bias=True) 184 | 185 | def forward(self, input, predict=False): 186 | 187 | output = self.initial_block(input) 188 | 189 | for layer in self.layers: 190 | output = layer(output) 191 | 192 | if predict: 193 | output = self.output_conv(output) 194 | 195 | return output 196 | 197 | class Interpolate(nn.Module): 198 | def __init__(self,size,mode): 199 | super(Interpolate,self).__init__() 200 | 201 | self.interp = nn.functional.interpolate 202 | self.size = size 203 | self.mode = mode 204 | def forward(self,x): 205 | x = self.interp(x,size=self.size,mode=self.mode,align_corners=True) 206 | return x 207 | 208 | 209 | class APN_Module(nn.Module): 210 | def __init__(self, in_ch, out_ch): 211 | super(APN_Module, self).__init__() 212 | # global pooling branch 213 | self.branch1 = nn.Sequential( 214 | nn.AdaptiveAvgPool2d(1), 215 | Conv2dBnRelu(in_ch, 
out_ch, kernel_size=1, stride=1, padding=0) 216 | ) 217 | 218 | # midddle branch 219 | self.mid = nn.Sequential( 220 | Conv2dBnRelu(in_ch, out_ch, kernel_size=1, stride=1, padding=0) 221 | ) 222 | 223 | self.down1 = Conv2dBnRelu(in_ch, 1, kernel_size=7, stride=2, padding=3) 224 | 225 | self.down2 = Conv2dBnRelu(1, 1, kernel_size=5, stride=2, padding=2) 226 | 227 | self.down3 = nn.Sequential( 228 | Conv2dBnRelu(1, 1, kernel_size=3, stride=2, padding=1), 229 | Conv2dBnRelu(1, 1, kernel_size=3, stride=1, padding=1) 230 | ) 231 | 232 | self.conv2 = Conv2dBnRelu(1, 1, kernel_size=5, stride=1, padding=2) 233 | self.conv1 = Conv2dBnRelu(1, 1, kernel_size=7, stride=1, padding=3) 234 | 235 | def forward(self, x): 236 | 237 | h = x.size()[2] 238 | w = x.size()[3] 239 | 240 | b1 = self.branch1(x) 241 | # b1 = Interpolate(size=(h, w), mode="bilinear")(b1) 242 | b1= interpolate(b1, size=(h, w), mode="bilinear", align_corners=True) 243 | 244 | mid = self.mid(x) 245 | 246 | x1 = self.down1(x) 247 | x2 = self.down2(x1) 248 | x3 = self.down3(x2) 249 | # x3 = Interpolate(size=(h // 4, w // 4), mode="bilinear")(x3) 250 | x3= interpolate(x3, size=(h // 4, w // 4), mode="bilinear", align_corners=True) 251 | x2 = self.conv2(x2) 252 | x = x2 + x3 253 | # x = Interpolate(size=(h // 2, w // 2), mode="bilinear")(x) 254 | x= interpolate(x, size=(h // 2, w // 2), mode="bilinear", align_corners=True) 255 | 256 | x1 = self.conv1(x1) 257 | x = x + x1 258 | # x = Interpolate(size=(h, w), mode="bilinear")(x) 259 | x= interpolate(x, size=(h, w), mode="bilinear", align_corners=True) 260 | 261 | x = torch.mul(x, mid) 262 | 263 | x = x + b1 264 | 265 | 266 | return x 267 | 268 | 269 | 270 | 271 | class Decoder (nn.Module): 272 | def __init__(self, num_classes): 273 | super().__init__() 274 | 275 | self.apn = APN_Module(in_ch=128,out_ch=20) 276 | # self.upsample = Interpolate(size=(512, 1024), mode="bilinear") 277 | # self.output_conv = nn.ConvTranspose2d(16, num_classes, kernel_size=4, stride=2, padding=1, output_padding=0, bias=True) 278 | # self.output_conv = nn.ConvTranspose2d(16, num_classes, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True) 279 | # self.output_conv = nn.ConvTranspose2d(16, num_classes, kernel_size=2, stride=2, padding=0, output_padding=0, bias=True) 280 | 281 | def forward(self, input): 282 | 283 | output = self.apn(input) 284 | out = interpolate(output, size=(512, 1024), mode="bilinear", align_corners=True) 285 | # out = self.upsample(output) 286 | return out 287 | 288 | 289 | # LEDNet 290 | class Net(nn.Module): 291 | def __init__(self, num_classes, encoder=None): 292 | super().__init__() 293 | 294 | if (encoder == None): 295 | self.encoder = Encoder(num_classes) 296 | else: 297 | self.encoder = encoder 298 | self.decoder = Decoder(num_classes) 299 | 300 | def forward(self, input, only_encode=False): 301 | if only_encode: 302 | return self.encoder.forward(input, predict=True) 303 | else: 304 | output = self.encoder(input) 305 | return self.decoder.forward(output) 306 | 307 | 308 | -------------------------------------------------------------------------------- /test/transform.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from PIL import Image 5 | 6 | def colormap_cityscapes(n): 7 | cmap=np.zeros([n, 3]).astype(np.uint8) 8 | cmap[0,:] = np.array([128, 64,128]) 9 | cmap[1,:] = np.array([244, 35,232]) 10 | cmap[2,:] = np.array([ 70, 70, 70]) 11 | cmap[3,:] = np.array([ 102,102,156]) 12 | cmap[4,:] = 
np.array([ 190,153,153]) 13 | cmap[5,:] = np.array([ 153,153,153]) 14 | 15 | cmap[6,:] = np.array([ 250,170, 30]) 16 | cmap[7,:] = np.array([ 220,220, 0]) 17 | cmap[8,:] = np.array([ 107,142, 35]) 18 | cmap[9,:] = np.array([ 152,251,152]) 19 | cmap[10,:] = np.array([ 70,130,180]) 20 | 21 | cmap[11,:] = np.array([ 220, 20, 60]) 22 | cmap[12,:] = np.array([ 255, 0, 0]) 23 | cmap[13,:] = np.array([ 0, 0,142]) 24 | cmap[14,:] = np.array([ 0, 0, 70]) 25 | cmap[15,:] = np.array([ 0, 60,100]) 26 | 27 | cmap[16,:] = np.array([ 0, 80,100]) 28 | cmap[17,:] = np.array([ 0, 0,230]) 29 | cmap[18,:] = np.array([ 119, 11, 32]) 30 | cmap[19,:] = np.array([ 0, 0, 0]) 31 | 32 | return cmap 33 | 34 | 35 | def colormap(n): 36 | cmap=np.zeros([n, 3]).astype(np.uint8) 37 | 38 | for i in np.arange(n): 39 | r, g, b = np.zeros(3) 40 | 41 | for j in np.arange(8): 42 | r = r + (1<<(7-j))*((i&(1<<(3*j))) >> (3*j)) 43 | g = g + (1<<(7-j))*((i&(1<<(3*j+1))) >> (3*j+1)) 44 | b = b + (1<<(7-j))*((i&(1<<(3*j+2))) >> (3*j+2)) 45 | 46 | cmap[i,:] = np.array([r, g, b]) 47 | 48 | return cmap 49 | 50 | class Relabel: 51 | 52 | def __init__(self, olabel, nlabel): 53 | self.olabel = olabel 54 | self.nlabel = nlabel 55 | 56 | def __call__(self, tensor): 57 | assert isinstance(tensor, torch.LongTensor) or isinstance(tensor, torch.ByteTensor) , 'tensor needs to be LongTensor' 58 | tensor[tensor == self.olabel] = self.nlabel 59 | return tensor 60 | 61 | 62 | class ToLabel: 63 | 64 | def __call__(self, image): 65 | return torch.from_numpy(np.array(image)).long().unsqueeze(0) 66 | 67 | 68 | class Colorize: 69 | 70 | def __init__(self, n=22): 71 | #self.cmap = colormap(256) 72 | self.cmap = colormap_cityscapes(256) 73 | self.cmap[n] = self.cmap[-1] 74 | self.cmap = torch.from_numpy(self.cmap[:n]) 75 | 76 | def __call__(self, gray_image): 77 | size = gray_image.size() 78 | color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0) 79 | 80 | #for label in range(1, len(self.cmap)): 81 | for label in range(0, len(self.cmap)): 82 | mask = gray_image[0] == label 83 | 84 | color_image[0][mask] = self.cmap[label][0] 85 | color_image[1][mask] = self.cmap[label][1] 86 | color_image[2][mask] = self.cmap[label][2] 87 | 88 | return color_image 89 | -------------------------------------------------------------------------------- /train/README.md: -------------------------------------------------------------------------------- 1 | #### Training LEDNet in PyTorch 2 | 3 | PyTorch code for training the LEDNet model on Cityscapes. The code was initially based on [bodokaiser/piwise](https://github.com/bodokaiser/piwise) and adapted with several custom modifications and tweaks. Some of them are: 4 | - Load cityscapes dataset 5 | - LEDNet model definition 6 | - Calculate IoU on each epoch during training 7 | - Save snapshots and best model during training 8 | - Save additional output files useful for checking results (see below "Output files...") 9 | - Resume training from checkpoint (use "--resume" flag in the command) 10 | 11 | #### Options 12 | For all options and defaults please see the bottom of the "main.py" file. Required ones are --savedir (name for creating a new folder with all the outputs of the training) and --datadir (path to cityscapes directory). 13 | 14 | #### Example commands 15 | Train the encoder for 300+ epochs with batch size 5 and then train the decoder (decoder training starts automatically after encoder training), for example: 16 | ``` 17 | python main.py --savedir logs --datadir /home/datasets/cityscapes/ --num-epochs 300 --batch-size 5 ...
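# Added example: the same run can be resumed later from its last checkpoint in the same save folder by reusing the command with the --resume flag, e.g.:
python main.py --savedir logs --datadir /home/datasets/cityscapes/ --num-epochs 300 --batch-size 5 --resume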
18 | ``` 19 | 20 | Each training will create a new folder in the "LEDNet_master/save/" directory named with the parameter --savedir and the following files: 21 | * **{model}.py**: copy of the model file used (default lednet.py). 22 | * **model.txt**: Plain text file that displays the model's layers. 23 | * **model_best.pth**: saved weights of the epoch that achieved the best val accuracy. 24 | * **model_best.pth.tar**: Same parameters as "checkpoint.pth.tar" but for the epoch with the best val accuracy. 25 | * **opts.txt**: Plain text file containing the options used for this training. 26 | * **automated_log.txt**: Plain text file that contains in columns the following info of each epoch {Epoch, Train-loss, Test-loss, Train-IoU, Test-IoU, learningRate}. Can be plotted with Gnuplot, Excel, or Matplotlib. 27 | * **best.txt**: Plain text file containing a line with the best IoU achieved during training and its epoch. 28 | * **checkpoint.pth.tar**: bundle file containing the checkpoint of the last trained epoch, with the following elements: 'epoch' (epoch number as int), 'arch' (net definition as a string), 'state_dict' (saved weights dictionary loadable by pytorch), 'best_acc' (best achieved accuracy as float), 'optimizer' (saved optimizer parameters). 29 | 30 | NOTE: Encoder trainings add an "_encoder" tag to each file's name. 31 | 32 | #### IoU display during training 33 | 34 | NEW: In previous code, IoU was calculated using a port of the cityscapes scripts, but new code has been added in "iouEval.py" to make it class-general, independent of other code, and much faster (using CUDA). 35 | 36 | By default, only Validation IoU is calculated for faster training (this can be changed in the options). 37 | 38 | #### Visualization 39 | If you want to visualize the outputs during training, add the "--visualize" flag and open an extra tab with: 40 | ``` 41 | python -m visdom.server -port 8097 42 | ``` 43 | The plots will be available in the browser at http://localhost:8097 44 | 45 | #### Multi-GPU 46 | If you wish to specify which GPUs to use, set the CUDA_VISIBLE_DEVICES environment variable: 47 | ``` 48 | CUDA_VISIBLE_DEVICES=0 python main.py ... 49 | CUDA_VISIBLE_DEVICES=0,1 python main.py ...
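# Added example: a full two-GPU run with the same required --savedir and --datadir options as above might look like:
CUDA_VISIBLE_DEVICES=0,1 python main.py --savedir logs --datadir /home/datasets/cityscapes/ --num-epochs 300 --batch-size 5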
50 | ``` 51 | 52 | 53 | -------------------------------------------------------------------------------- /train/lednet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import torch.nn.functional as F 5 | from torch.nn.functional import interpolate as interpolate 6 | 7 | 8 | def split(x): 9 | c = int(x.size()[1]) 10 | c1 = round(c * 0.5) 11 | x1 = x[:, :c1, :, :].contiguous() 12 | x2 = x[:, c1:, :, :].contiguous() 13 | 14 | return x1, x2 15 | 16 | def channel_shuffle(x,groups): 17 | batchsize, num_channels, height, width = x.data.size() 18 | 19 | channels_per_group = num_channels // groups 20 | 21 | # reshape 22 | x = x.view(batchsize,groups, 23 | channels_per_group,height,width) 24 | 25 | x = torch.transpose(x,1,2).contiguous() 26 | 27 | # flatten 28 | x = x.view(batchsize,-1,height,width) 29 | 30 | return x 31 | 32 | 33 | class Conv2dBnRelu(nn.Module): 34 | def __init__(self,in_ch,out_ch,kernel_size=3,stride=1,padding=0,dilation=1,bias=True): 35 | super(Conv2dBnRelu,self).__init__() 36 | 37 | self.conv = nn.Sequential( 38 | nn.Conv2d(in_ch,out_ch,kernel_size,stride,padding,dilation=dilation,bias=bias), 39 | nn.BatchNorm2d(out_ch, eps=1e-3), 40 | nn.ReLU(inplace=True) 41 | ) 42 | 43 | def forward(self, x): 44 | return self.conv(x) 45 | 46 | 47 | # after Concat -> BN, you also can use Dropout like SS_nbt_module may be make a good result! 48 | class DownsamplerBlock (nn.Module): 49 | def __init__(self, in_channel, out_channel): 50 | super(DownsamplerBlock,self).__init__() 51 | 52 | self.conv = nn.Conv2d(in_channel, out_channel-in_channel, (3, 3), stride=2, padding=1, bias=True) 53 | self.pool = nn.MaxPool2d(2, stride=2) 54 | self.bn = nn.BatchNorm2d(out_channel, eps=1e-3) 55 | self.relu = nn.ReLU(inplace=True) 56 | 57 | def forward(self, input): 58 | x1 = self.pool(input) 59 | x2 = self.conv(input) 60 | 61 | diffY = x2.size()[2] - x1.size()[2] 62 | diffX = x2.size()[3] - x1.size()[3] 63 | 64 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, 65 | diffY // 2, diffY - diffY // 2]) 66 | 67 | output = torch.cat([x2, x1], 1) 68 | output = self.bn(output) 69 | output = self.relu(output) 70 | return output 71 | 72 | 73 | class SS_nbt_module(nn.Module): 74 | def __init__(self, chann, dropprob, dilated): 75 | super().__init__() 76 | 77 | oup_inc = chann//2 78 | 79 | # dw 80 | self.conv3x1_1_l = nn.Conv2d(oup_inc, oup_inc, (3,1), stride=1, padding=(1,0), bias=True) 81 | 82 | self.conv1x3_1_l = nn.Conv2d(oup_inc, oup_inc, (1,3), stride=1, padding=(0,1), bias=True) 83 | 84 | self.bn1_l = nn.BatchNorm2d(oup_inc, eps=1e-03) 85 | 86 | self.conv3x1_2_l = nn.Conv2d(oup_inc, oup_inc, (3,1), stride=1, padding=(1*dilated,0), bias=True, dilation = (dilated,1)) 87 | 88 | self.conv1x3_2_l = nn.Conv2d(oup_inc, oup_inc, (1,3), stride=1, padding=(0,1*dilated), bias=True, dilation = (1,dilated)) 89 | 90 | self.bn2_l = nn.BatchNorm2d(oup_inc, eps=1e-03) 91 | 92 | # dw 93 | self.conv3x1_1_r = nn.Conv2d(oup_inc, oup_inc, (3,1), stride=1, padding=(1,0), bias=True) 94 | 95 | self.conv1x3_1_r = nn.Conv2d(oup_inc, oup_inc, (1,3), stride=1, padding=(0,1), bias=True) 96 | 97 | self.bn1_r = nn.BatchNorm2d(oup_inc, eps=1e-03) 98 | 99 | self.conv3x1_2_r = nn.Conv2d(oup_inc, oup_inc, (3,1), stride=1, padding=(1*dilated,0), bias=True, dilation = (dilated,1)) 100 | 101 | self.conv1x3_2_r = nn.Conv2d(oup_inc, oup_inc, (1,3), stride=1, padding=(0,1*dilated), bias=True, dilation = (1,dilated)) 102 | 103 | self.bn2_r = 
nn.BatchNorm2d(oup_inc, eps=1e-03) 104 | 105 | self.relu = nn.ReLU(inplace=True) 106 | self.dropout = nn.Dropout2d(dropprob) 107 | 108 | @staticmethod 109 | def _concat(x,out): 110 | return torch.cat((x,out),1) 111 | 112 | def forward(self, input): 113 | 114 | # x1 = input[:,:(input.shape[1]//2),:,:] 115 | # x2 = input[:,(input.shape[1]//2):,:,:] 116 | residual = input 117 | x1, x2 = split(input) 118 | 119 | output1 = self.conv3x1_1_l(x1) 120 | output1 = self.relu(output1) 121 | output1 = self.conv1x3_1_l(output1) 122 | output1 = self.bn1_l(output1) 123 | output1 = self.relu(output1) 124 | 125 | output1 = self.conv3x1_2_l(output1) 126 | output1 = self.relu(output1) 127 | output1 = self.conv1x3_2_l(output1) 128 | output1 = self.bn2_l(output1) 129 | 130 | 131 | output2 = self.conv1x3_1_r(x2) 132 | output2 = self.relu(output2) 133 | output2 = self.conv3x1_1_r(output2) 134 | output2 = self.bn1_r(output2) 135 | output2 = self.relu(output2) 136 | 137 | output2 = self.conv1x3_2_r(output2) 138 | output2 = self.relu(output2) 139 | output2 = self.conv3x1_2_r(output2) 140 | output2 = self.bn2_r(output2) 141 | 142 | if (self.dropout.p != 0): 143 | output1 = self.dropout(output1) 144 | output2 = self.dropout(output2) 145 | 146 | out = self._concat(output1,output2) 147 | out = F.relu(residual + out, inplace=True) 148 | return channel_shuffle(out,2) 149 | 150 | 151 | 152 | class Encoder(nn.Module): 153 | def __init__(self, num_classes): 154 | super().__init__() 155 | 156 | self.initial_block = DownsamplerBlock(3,32) 157 | 158 | self.layers = nn.ModuleList() 159 | 160 | for x in range(0, 3): 161 | self.layers.append(SS_nbt_module(32, 0.03, 1)) 162 | 163 | 164 | self.layers.append(DownsamplerBlock(32,64)) 165 | 166 | 167 | for x in range(0, 2): 168 | self.layers.append(SS_nbt_module(64, 0.03, 1)) 169 | 170 | self.layers.append(DownsamplerBlock(64,128)) 171 | 172 | for x in range(0, 1): 173 | self.layers.append(SS_nbt_module(128, 0.3, 1)) 174 | self.layers.append(SS_nbt_module(128, 0.3, 2)) 175 | self.layers.append(SS_nbt_module(128, 0.3, 5)) 176 | self.layers.append(SS_nbt_module(128, 0.3, 9)) 177 | 178 | for x in range(0, 1): 179 | self.layers.append(SS_nbt_module(128, 0.3, 2)) 180 | self.layers.append(SS_nbt_module(128, 0.3, 5)) 181 | self.layers.append(SS_nbt_module(128, 0.3, 9)) 182 | self.layers.append(SS_nbt_module(128, 0.3, 17)) 183 | 184 | 185 | # Only in encoder mode: 186 | self.output_conv = nn.Conv2d(128, num_classes, 1, stride=1, padding=0, bias=True) 187 | 188 | def forward(self, input, predict=False): 189 | 190 | output = self.initial_block(input) 191 | 192 | for layer in self.layers: 193 | output = layer(output) 194 | 195 | if predict: 196 | output = self.output_conv(output) 197 | 198 | return output 199 | 200 | class Interpolate(nn.Module): 201 | def __init__(self,size,mode): 202 | super(Interpolate,self).__init__() 203 | 204 | self.interp = nn.functional.interpolate 205 | self.size = size 206 | self.mode = mode 207 | def forward(self,x): 208 | x = self.interp(x,size=self.size,mode=self.mode,align_corners=True) 209 | return x 210 | 211 | 212 | class APN_Module(nn.Module): 213 | def __init__(self, in_ch, out_ch): 214 | super(APN_Module, self).__init__() 215 | # global pooling branch 216 | self.branch1 = nn.Sequential( 217 | nn.AdaptiveAvgPool2d(1), 218 | Conv2dBnRelu(in_ch, out_ch, kernel_size=1, stride=1, padding=0) 219 | ) 220 | # midddle branch 221 | self.mid = nn.Sequential( 222 | Conv2dBnRelu(in_ch, out_ch, kernel_size=1, stride=1, padding=0) 223 | ) 224 | self.down1 = 
Conv2dBnRelu(in_ch, 1, kernel_size=7, stride=2, padding=3) 225 | 226 | self.down2 = Conv2dBnRelu(1, 1, kernel_size=5, stride=2, padding=2) 227 | 228 | self.down3 = nn.Sequential( 229 | Conv2dBnRelu(1, 1, kernel_size=3, stride=2, padding=1), 230 | Conv2dBnRelu(1, 1, kernel_size=3, stride=1, padding=1) 231 | ) 232 | 233 | self.conv2 = Conv2dBnRelu(1, 1, kernel_size=5, stride=1, padding=2) 234 | self.conv1 = Conv2dBnRelu(1, 1, kernel_size=7, stride=1, padding=3) 235 | 236 | def forward(self, x): 237 | 238 | h = x.size()[2] 239 | w = x.size()[3] 240 | 241 | b1 = self.branch1(x) 242 | # b1 = Interpolate(size=(h, w), mode="bilinear")(b1) 243 | b1= interpolate(b1, size=(h, w), mode="bilinear", align_corners=True) 244 | 245 | mid = self.mid(x) 246 | 247 | x1 = self.down1(x) 248 | x2 = self.down2(x1) 249 | x3 = self.down3(x2) 250 | # x3 = Interpolate(size=(h // 4, w // 4), mode="bilinear")(x3) 251 | x3= interpolate(x3, size=(h // 4, w // 4), mode="bilinear", align_corners=True) 252 | x2 = self.conv2(x2) 253 | x = x2 + x3 254 | # x = Interpolate(size=(h // 2, w // 2), mode="bilinear")(x) 255 | x= interpolate(x, size=(h // 2, w // 2), mode="bilinear", align_corners=True) 256 | 257 | x1 = self.conv1(x1) 258 | x = x + x1 259 | # x = Interpolate(size=(h, w), mode="bilinear")(x) 260 | x= interpolate(x, size=(h, w), mode="bilinear", align_corners=True) 261 | 262 | x = torch.mul(x, mid) 263 | 264 | x = x + b1 265 | 266 | 267 | return x 268 | 269 | 270 | 271 | 272 | class Decoder (nn.Module): 273 | def __init__(self, num_classes): 274 | super().__init__() 275 | 276 | self.apn = APN_Module(in_ch=128,out_ch=20) 277 | # self.upsample = Interpolate(size=(512, 1024), mode="bilinear") 278 | # self.output_conv = nn.ConvTranspose2d(16, num_classes, kernel_size=4, stride=2, padding=1, output_padding=0, bias=True) 279 | # self.output_conv = nn.ConvTranspose2d(16, num_classes, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True) 280 | # self.output_conv = nn.ConvTranspose2d(16, num_classes, kernel_size=2, stride=2, padding=0, output_padding=0, bias=True) 281 | 282 | def forward(self, input): 283 | 284 | output = self.apn(input) 285 | out = interpolate(output, size=(512, 1024), mode="bilinear", align_corners=True) 286 | # out = self.upsample(output) 287 | # print(out.shape) 288 | return out 289 | 290 | 291 | # LEDNet 292 | class Net(nn.Module): 293 | def __init__(self, num_classes, encoder=None): 294 | super().__init__() 295 | 296 | if (encoder == None): 297 | self.encoder = Encoder(num_classes) 298 | else: 299 | self.encoder = encoder 300 | self.decoder = Decoder(num_classes) 301 | 302 | def forward(self, input, only_encode=False): 303 | if only_encode: 304 | return self.encoder.forward(input, predict=True) 305 | else: 306 | output = self.encoder(input) 307 | return self.decoder.forward(output) 308 | -------------------------------------------------------------------------------- /train/lednet_1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import torch.nn.functional as F 5 | from torch.nn.functional import interpolate as interpolate 6 | 7 | def split(x): 8 | c = int(x.size()[1]) 9 | c1 = round(c * 0.5) 10 | x1 = x[:, :c1, :, :].contiguous() 11 | x2 = x[:, c1:, :, :].contiguous() 12 | 13 | return x1, x2 14 | 15 | def channel_shuffle(x,groups): 16 | batchsize, num_channels, height, width = x.data.size() 17 | 18 | channels_per_group = num_channels // groups 19 | 20 | #reshape 21 | x = 
x.view(batchsize,groups, 22 | channels_per_group,height,width) 23 | 24 | x = torch.transpose(x,1,2).contiguous() 25 | 26 | #flatten 27 | x = x.view(batchsize,-1,height,width) 28 | 29 | return x 30 | 31 | 32 | class Conv2dBnRelu(nn.Module): 33 | def __init__(self,in_ch,out_ch,kernel_size=3,stride=1,padding=0,dilation=1,bias=True): 34 | super(Conv2dBnRelu,self).__init__() 35 | 36 | self.conv = nn.Sequential( 37 | nn.Conv2d(in_ch,out_ch,kernel_size,stride,padding,dilation=dilation,bias=bias), 38 | nn.BatchNorm2d(out_ch, eps=1e-3), 39 | nn.ReLU(inplace=True) 40 | ) 41 | 42 | def forward(self, x): 43 | return self.conv(x) 44 | 45 | 46 | ##after Concat -> BN, you also can use Dropout like SS_nbt_module may be make a good result! 47 | class DownsamplerBlock (nn.Module): 48 | def __init__(self, in_channel, out_channel): 49 | super(DownsamplerBlock,self).__init__() 50 | 51 | self.conv = nn.Conv2d(in_channel, out_channel-in_channel, (3, 3), stride=2, padding=1, bias=True) 52 | self.pool = nn.MaxPool2d(2, stride=2) 53 | self.bn = nn.BatchNorm2d(out_channel, eps=1e-3) 54 | self.relu = nn.ReLU(inplace=True) 55 | 56 | def forward(self, input): 57 | x1 = self.pool(input) 58 | x2 = self.conv(input) 59 | 60 | diffY = x2.size()[2] - x1.size()[2] 61 | diffX = x2.size()[3] - x1.size()[3] 62 | 63 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, 64 | diffY // 2, diffY - diffY // 2]) 65 | 66 | output = torch.cat([x2, x1], 1) 67 | output = self.bn(output) 68 | output = self.relu(output) 69 | return output 70 | 71 | 72 | class SS_nbt_module(nn.Module): 73 | def __init__(self, chann, dropprob, dilated): 74 | super(SS_nbt_module,self).__init__() 75 | 76 | oup_inc = chann//2 77 | 78 | #dw 79 | self.conv3x1_1_l = nn.Conv2d(oup_inc, oup_inc, (3,1), stride=1, padding=(1,0), bias=True) 80 | 81 | self.conv1x3_1_l = nn.Conv2d(oup_inc, oup_inc, (1,3), stride=1, padding=(0,1), bias=True) 82 | 83 | self.bn1_l = nn.BatchNorm2d(oup_inc, eps=1e-03) 84 | 85 | self.conv3x1_2_l = nn.Conv2d(oup_inc, oup_inc, (3,1), stride=1, padding=(1*dilated,0), bias=True, dilation = (dilated,1)) 86 | 87 | self.conv1x3_2_l = nn.Conv2d(oup_inc, oup_inc, (1,3), stride=1, padding=(0,1*dilated), bias=True, dilation = (1,dilated)) 88 | 89 | self.bn2_l = nn.BatchNorm2d(oup_inc, eps=1e-03) 90 | 91 | #dw 92 | self.conv3x1_1_r = nn.Conv2d(oup_inc, oup_inc, (3,1), stride=1, padding=(1,0), bias=True) 93 | 94 | self.conv1x3_1_r = nn.Conv2d(oup_inc, oup_inc, (1,3), stride=1, padding=(0,1), bias=True) 95 | 96 | self.bn1_r = nn.BatchNorm2d(oup_inc, eps=1e-03) 97 | 98 | self.conv3x1_2_r = nn.Conv2d(oup_inc, oup_inc, (3,1), stride=1, padding=(1*dilated,0), bias=True, dilation = (dilated,1)) 99 | 100 | self.conv1x3_2_r = nn.Conv2d(oup_inc, oup_inc, (1,3), stride=1, padding=(0,1*dilated), bias=True, dilation = (1,dilated)) 101 | 102 | self.bn2_r = nn.BatchNorm2d(oup_inc, eps=1e-03) 103 | 104 | self.relu = nn.ReLU(inplace=True) 105 | self.dropout = nn.Dropout2d(dropprob) 106 | 107 | @staticmethod 108 | def _concat(x,out): 109 | return torch.cat((x,out),1) 110 | 111 | def forward(self, input): 112 | 113 | # x1 = input[:,:(input.shape[1]//2),:,:] 114 | # x2 = input[:,(input.shape[1]//2):,:,:] 115 | 116 | residual = input 117 | x1, x2 = split(input) 118 | 119 | output1 = self.conv3x1_1_l(x1) 120 | output1 = self.relu(output1) 121 | output1 = self.conv1x3_1_l(output1) 122 | output1 = self.bn1_l(output1) 123 | output1 = self.relu(output1) 124 | 125 | output1 = self.conv3x1_2_l(output1) 126 | output1 = self.relu(output1) 127 | output1 = self.conv1x3_2_l(output1) 128 | 
output1 = self.bn2_l(output1) 129 | 130 | output2 = self.conv1x3_1_r(x2) 131 | output2 = self.relu(output2) 132 | output2 = self.conv3x1_1_r(output2) 133 | output2 = self.bn1_r(output2) 134 | output2 = self.relu(output2) 135 | 136 | output2 = self.conv1x3_2_r(output2) 137 | output2 = self.relu(output2) 138 | output2 = self.conv3x1_2_r(output2) 139 | output2 = self.bn2_r(output2) 140 | 141 | if (self.dropout.p != 0): 142 | output1 = self.dropout(output1) 143 | output2 = self.dropout(output2) 144 | 145 | out = self._concat(output1,output2) 146 | out = F.relu(residual + out, inplace=True) 147 | 148 | return channel_shuffle(out, 2) 149 | 150 | 151 | class Encoder(nn.Module): 152 | def __init__(self, num_classes): 153 | super().__init__() 154 | self.initial_block = DownsamplerBlock(3,32) 155 | 156 | 157 | self.layers = nn.ModuleList() 158 | 159 | for x in range(0, 3): 160 | self.layers.append(SS_nbt_module(32, 0.03, 1)) 161 | 162 | 163 | self.layers.append(DownsamplerBlock(32,64)) 164 | 165 | 166 | for x in range(0, 2): 167 | self.layers.append(SS_nbt_module(64, 0.03, 1)) 168 | 169 | self.layers.append(DownsamplerBlock(64,128)) 170 | 171 | for x in range(0, 1): 172 | self.layers.append(SS_nbt_module(128, 0.3, 1)) 173 | self.layers.append(SS_nbt_module(128, 0.3, 2)) 174 | self.layers.append(SS_nbt_module(128, 0.3, 5)) 175 | self.layers.append(SS_nbt_module(128, 0.3, 9)) 176 | 177 | for x in range(0, 1): 178 | self.layers.append(SS_nbt_module(128, 0.3, 2)) 179 | self.layers.append(SS_nbt_module(128, 0.3, 5)) 180 | self.layers.append(SS_nbt_module(128, 0.3, 9)) 181 | self.layers.append(SS_nbt_module(128, 0.3, 17)) 182 | 183 | 184 | #Only in encoder mode: 185 | self.output_conv = nn.Conv2d(128, num_classes, 1, stride=1, padding=0, bias=True) 186 | 187 | def forward(self, input, predict=False): 188 | 189 | output = self.initial_block(input) 190 | 191 | for layer in self.layers: 192 | output = layer(output) 193 | 194 | if predict: 195 | output = self.output_conv(output) 196 | 197 | return output 198 | 199 | class Interpolate(nn.Module): 200 | def __init__(self,size,mode): 201 | super(Interpolate,self).__init__() 202 | 203 | self.interp = nn.functional.interpolate 204 | self.size = size 205 | self.mode = mode 206 | def forward(self,x): 207 | x = self.interp(x,size=self.size,mode=self.mode,align_corners=True) 208 | return x 209 | 210 | 211 | class APN_Module(nn.Module): 212 | def __init__(self, in_ch, out_ch): 213 | super(APN_Module, self).__init__() 214 | # global pooling branch 215 | self.branch1 = nn.Sequential( 216 | nn.AdaptiveAvgPool2d(1), 217 | Conv2dBnRelu(in_ch, out_ch, kernel_size=1, stride=1, padding=0) 218 | ) 219 | 220 | # midddle branch 221 | self.mid = nn.Sequential( 222 | Conv2dBnRelu(in_ch, out_ch, kernel_size=1, stride=1, padding=0) 223 | ) 224 | 225 | self.down1 = Conv2dBnRelu(in_ch, 128, kernel_size=7, stride=2, padding=3) 226 | 227 | self.down2 = Conv2dBnRelu(128, 128, kernel_size=5, stride=2, padding=2) 228 | 229 | self.down3 = nn.Sequential( 230 | Conv2dBnRelu(128, 128, kernel_size=3, stride=2, padding=1), 231 | Conv2dBnRelu(128, 20, kernel_size=1, stride=1, padding=0) 232 | ) 233 | 234 | self.conv2 = Conv2dBnRelu(128, 20, kernel_size=1, stride=1, padding=0) 235 | self.conv1 = Conv2dBnRelu(128, 20, kernel_size=1, stride=1, padding=0) 236 | 237 | def forward(self, x): 238 | 239 | h = x.size()[2] 240 | w = x.size()[3] 241 | 242 | b1 = self.branch1(x) 243 | #b1 = Interpolate(size=(h, w), mode="bilinear")(b1) 244 | b1= interpolate(b1, size=(h, w), mode="bilinear", align_corners=True) 
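        # Added descriptive note: b1 is the global-context branch, average-pooled to 1x1, projected to out_ch, then upsampled back to (h, w)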
245 | 246 | mid = self.mid(x) 247 | 248 | x1 = self.down1(x) 249 | x2 = self.down2(x1) 250 | x3 = self.down3(x2) 251 | #x3 = Interpolate(size=(h // 4, w // 4), mode="bilinear")(x3) 252 | x3= interpolate(x3, size=(h // 4, w // 4), mode="bilinear", align_corners=True) 253 | x2 = self.conv2(x2) 254 | x = x2 + x3 255 | #x = Interpolate(size=(h // 2, w // 2), mode="bilinear")(x) 256 | x= interpolate(x, size=(h // 2, w // 2), mode="bilinear", align_corners=True) 257 | 258 | x1 = self.conv1(x1) 259 | x = x + x1 260 | #x = Interpolate(size=(h, w), mode="bilinear")(x) 261 | x= interpolate(x, size=(h, w), mode="bilinear", align_corners=True) 262 | 263 | x = torch.mul(x, mid) 264 | 265 | x = x + b1 266 | 267 | 268 | return x 269 | 270 | 271 | 272 | 273 | class Decoder (nn.Module): 274 | def __init__(self, num_classes): 275 | super().__init__() 276 | 277 | self.apn = APN_Module(in_ch=128,out_ch=20) 278 | #self.upsample = Interpolate(size=(512, 1024), mode="bilinear") 279 | #self.output_conv = nn.ConvTranspose2d(16, num_classes, kernel_size=4, stride=2, padding=1, output_padding=0, bias=True) 280 | #self.output_conv = nn.ConvTranspose2d(16, num_classes, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True) 281 | #self.output_conv = nn.ConvTranspose2d(16, num_classes, kernel_size=2, stride=2, padding=0, output_padding=0, bias=True) 282 | 283 | def forward(self, input): 284 | 285 | output = self.apn(input) 286 | out = interpolate(output, size=(512, 1024), mode="bilinear", align_corners=True) 287 | #out = self.upsample(output) 288 | return out 289 | 290 | 291 | # LEDNet 292 | class Net(nn.Module): 293 | def __init__(self, num_classes, encoder=None): 294 | super().__init__() 295 | 296 | if (encoder == None): 297 | self.encoder = Encoder(num_classes) 298 | else: 299 | self.encoder = encoder 300 | self.decoder = Decoder(num_classes) 301 | 302 | def forward(self, input, only_encode=False): 303 | if only_encode: 304 | return self.encoder.forward(input, predict=True) 305 | else: 306 | output = self.encoder(input) 307 | return self.decoder.forward(output) 308 | -------------------------------------------------------------------------------- /train/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import time 4 | import numpy as np 5 | import torch 6 | import math 7 | import sys 8 | 9 | cur_path = os.path.abspath(os.path.dirname(__file__)) 10 | root_path = os.path.split(cur_path)[0] 11 | sys.path.append(root_path) 12 | 13 | 14 | 15 | from PIL import Image, ImageOps 16 | from argparse import ArgumentParser 17 | 18 | from torch.optim import SGD, Adam, lr_scheduler 19 | from torch.autograd import Variable 20 | from torch.utils.data import DataLoader 21 | from torchvision.transforms import Compose, CenterCrop, Normalize, Resize, Pad 22 | from torchvision.transforms import ToTensor, ToPILImage 23 | 24 | from utils.dataset import VOC12,cityscapes 25 | from utils.transform import Relabel, ToLabel, Colorize 26 | from utils.visualize import Dashboard 27 | from utils.loss import CrossEntropyLoss2d 28 | 29 | import importlib 30 | from utils.iouEval import iouEval, getColorEntry 31 | 32 | from shutil import copyfile 33 | 34 | NUM_CHANNELS = 3 35 | NUM_CLASSES = 20 #pascal=22, cityscapes=20 36 | 37 | color_transform = Colorize(NUM_CLASSES) 38 | image_transform = ToPILImage() 39 | 40 | #Augmentations - different function implemented to perform random augments on both image and target 41 | class MyCoTransform(object): 42 | def __init__(self, 
enc, augment=True, height=512): 43 | self.enc=enc 44 | self.augment = augment 45 | self.height = height 46 | pass 47 | def __call__(self, input, target): 48 | # do something to both images 49 | input = Resize(self.height, Image.BILINEAR)(input) 50 | target = Resize(self.height, Image.NEAREST)(target) 51 | 52 | if(self.augment): 53 | # Random hflip 54 | hflip = random.random() 55 | if (hflip < 0.5): 56 | input = input.transpose(Image.FLIP_LEFT_RIGHT) 57 | target = target.transpose(Image.FLIP_LEFT_RIGHT) 58 | 59 | #Random translation 0-2 pixels (fill rest with padding 60 | transX = random.randint(-2, 2) 61 | transY = random.randint(-2, 2) 62 | 63 | input = ImageOps.expand(input, border=(transX,transY,0,0), fill=0) 64 | target = ImageOps.expand(target, border=(transX,transY,0,0), fill=255) #pad label filling with 255 65 | input = input.crop((0, 0, input.size[0]-transX, input.size[1]-transY)) 66 | target = target.crop((0, 0, target.size[0]-transX, target.size[1]-transY)) 67 | 68 | input = ToTensor()(input) 69 | if (self.enc): 70 | target = Resize(int(self.height/8), Image.NEAREST)(target) 71 | target = ToLabel()(target) 72 | target = Relabel(255, 19)(target) 73 | 74 | return input, target 75 | 76 | 77 | 78 | 79 | 80 | def train(args, model, enc=False): 81 | best_acc = 0 82 | 83 | #TODO: calculate weights by processing dataset histogram (now its being set by hand from the torch values) 84 | #create a loder to run all images and calculate histogram of labels, then create weight array using class balancing 85 | 86 | weight = torch.ones(NUM_CLASSES) 87 | if (enc): 88 | weight[0] = 2.3653597831726 89 | weight[1] = 4.4237880706787 90 | weight[2] = 2.9691488742828 91 | weight[3] = 5.3442072868347 92 | weight[4] = 5.2983593940735 93 | weight[5] = 5.2275490760803 94 | weight[6] = 5.4394111633301 95 | weight[7] = 5.3659925460815 96 | weight[8] = 3.4170460700989 97 | weight[9] = 5.2414722442627 98 | weight[10] = 4.7376127243042 99 | weight[11] = 5.2286224365234 100 | weight[12] = 5.455126285553 101 | weight[13] = 4.3019247055054 102 | weight[14] = 5.4264230728149 103 | weight[15] = 5.4331531524658 104 | weight[16] = 5.433765411377 105 | weight[17] = 5.4631009101868 106 | weight[18] = 5.3947434425354 107 | else: 108 | weight[0] = 2.8149201869965 109 | weight[1] = 6.9850029945374 110 | weight[2] = 3.7890393733978 111 | weight[3] = 9.9428062438965 112 | weight[4] = 9.7702074050903 113 | weight[5] = 9.5110931396484 114 | weight[6] = 10.311357498169 115 | weight[7] = 10.026463508606 116 | weight[8] = 4.6323022842407 117 | weight[9] = 9.5608062744141 118 | weight[10] = 7.8698215484619 119 | weight[11] = 9.5168733596802 120 | weight[12] = 10.373730659485 121 | weight[13] = 6.6616044044495 122 | weight[14] = 10.260489463806 123 | weight[15] = 10.287888526917 124 | weight[16] = 10.289801597595 125 | weight[17] = 10.405355453491 126 | weight[18] = 10.138095855713 127 | 128 | weight[19] = 0 129 | 130 | assert os.path.exists(args.datadir), "Error: datadir (dataset directory) could not be loaded" 131 | 132 | co_transform = MyCoTransform(enc, augment=True, height=args.height)#512) 133 | co_transform_val = MyCoTransform(enc, augment=False, height=args.height)#512) 134 | dataset_train = cityscapes(args.datadir, co_transform, 'train') 135 | dataset_val = cityscapes(args.datadir, co_transform_val, 'val') 136 | 137 | loader = DataLoader(dataset_train, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=True) 138 | loader_val = DataLoader(dataset_val, num_workers=args.num_workers, batch_size=args.batch_size, 
shuffle=False) 139 | 140 | if args.cuda: 141 | weight = weight.cuda() 142 | criterion = CrossEntropyLoss2d(weight) 143 | print(type(criterion)) 144 | 145 | savedir = f'../save/{args.savedir}' 146 | 147 | if (enc): 148 | automated_log_path = savedir + "/automated_log_encoder.txt" 149 | modeltxtpath = savedir + "/model_encoder.txt" 150 | else: 151 | automated_log_path = savedir + "/automated_log.txt" 152 | modeltxtpath = savedir + "/model.txt" 153 | 154 | if (not os.path.exists(automated_log_path)): #dont add first line if it exists 155 | with open(automated_log_path, "a") as myfile: 156 | myfile.write("Epoch\t\tTrain-loss\t\tTest-loss\t\tTrain-IoU\t\tTest-IoU\t\tlearningRate") 157 | 158 | with open(modeltxtpath, "w") as myfile: 159 | myfile.write(str(model)) 160 | 161 | 162 | #TODO: reduce memory in first gpu: https://discuss.pytorch.org/t/multi-gpu-training-memory-usage-in-balance/4163/4 163 | #https://github.com/pytorch/pytorch/issues/1893 164 | 165 | #optimizer = Adam(model.parameters(), 5e-4, (0.9, 0.999), eps=1e-08, weight_decay=2e-4) ## scheduler 1 166 | optimizer = Adam(model.parameters(), 5e-4, (0.9, 0.999), eps=1e-08, weight_decay=1e-4) ## scheduler 2 167 | 168 | start_epoch = 1 169 | if args.resume: 170 | #Must load weights, optimizer, epoch and best value. 171 | if enc: 172 | filenameCheckpoint = savedir + '/checkpoint_enc.pth.tar' 173 | else: 174 | filenameCheckpoint = savedir + '/checkpoint.pth.tar' 175 | 176 | assert os.path.exists(filenameCheckpoint), "Error: resume option was used but checkpoint was not found in folder" 177 | checkpoint = torch.load(filenameCheckpoint) 178 | start_epoch = checkpoint['epoch'] 179 | model.load_state_dict(checkpoint['state_dict']) 180 | optimizer.load_state_dict(checkpoint['optimizer']) 181 | best_acc = checkpoint['best_acc'] 182 | print("=> Loaded checkpoint at epoch {})".format(checkpoint['epoch'])) 183 | 184 | #scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5) # set up scheduler ## scheduler 1 185 | lambda1 = lambda epoch: pow((1-((epoch-1)/args.num_epochs)),0.9) ## scheduler 2 186 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1) ## scheduler 2 187 | 188 | if args.visualize and args.steps_plot > 0: 189 | board = Dashboard(args.port) 190 | 191 | for epoch in range(start_epoch, args.num_epochs+1): 192 | print("----- TRAINING - EPOCH", epoch, "-----") 193 | 194 | scheduler.step(epoch) ## scheduler 2 195 | 196 | epoch_loss = [] 197 | time_train = [] 198 | 199 | doIouTrain = args.iouTrain 200 | doIouVal = args.iouVal 201 | 202 | if (doIouTrain): 203 | iouEvalTrain = iouEval(NUM_CLASSES) 204 | 205 | usedLr = 0 206 | for param_group in optimizer.param_groups: 207 | print("LEARNING RATE: ", param_group['lr']) 208 | usedLr = float(param_group['lr']) 209 | 210 | model.train() 211 | for step, (images, labels) in enumerate(loader): 212 | 213 | start_time = time.time() 214 | 215 | imgs_batch = images.shape[0] 216 | if imgs_batch != args.batch_size: 217 | break 218 | 219 | if args.cuda: 220 | inputs = images.cuda() 221 | targets = labels.cuda() 222 | 223 | outputs = model(inputs, only_encode=enc) 224 | 225 | #print("targets", np.unique(targets[:, 0].cpu().data.numpy())) 226 | 227 | optimizer.zero_grad() 228 | loss = criterion(outputs, targets[:, 0]) 229 | 230 | loss.backward() 231 | optimizer.step() 232 | 233 | epoch_loss.append(loss.item()) 234 | time_train.append(time.time() - start_time) 235 | 236 | if (doIouTrain): 237 | #start_time_iou = time.time() 238 | 
iouEvalTrain.addBatch(outputs.max(1)[1].unsqueeze(1).data, targets.data) 239 | #print ("Time to add confusion matrix: ", time.time() - start_time_iou) 240 | 241 | #print(outputs.size()) 242 | if args.visualize and args.steps_plot > 0 and step % args.steps_plot == 0: 243 | start_time_plot = time.time() 244 | image = inputs[0].cpu().data 245 | #image[0] = image[0] * .229 + .485 246 | #image[1] = image[1] * .224 + .456 247 | #image[2] = image[2] * .225 + .406 248 | #print("output", np.unique(outputs[0].cpu().max(0)[1].data.numpy())) 249 | board.image(image, f'input (epoch: {epoch}, step: {step})') 250 | if isinstance(outputs, list): #merge gpu tensors 251 | board.image(color_transform(outputs[0][0].cpu().max(0)[1].data.unsqueeze(0)), 252 | f'output (epoch: {epoch}, step: {step})') 253 | else: 254 | board.image(color_transform(outputs[0].cpu().max(0)[1].data.unsqueeze(0)), 255 | f'output (epoch: {epoch}, step: {step})') 256 | board.image(color_transform(targets[0].cpu().data), 257 | f'target (epoch: {epoch}, step: {step})') 258 | print ("Time to paint images: ", time.time() - start_time_plot) 259 | if args.steps_loss > 0 and step % args.steps_loss == 0: 260 | average = sum(epoch_loss) / len(epoch_loss) 261 | print(f'loss: {average:0.4} (epoch: {epoch}, step: {step})', 262 | "// Avg time/img: %.4f s" % (sum(time_train) / len(time_train) / args.batch_size)) 263 | 264 | 265 | average_epoch_loss_train = sum(epoch_loss) / len(epoch_loss) 266 | 267 | iouTrain = 0 268 | if (doIouTrain): 269 | iouTrain, iou_classes = iouEvalTrain.getIoU() 270 | iouStr = getColorEntry(iouTrain)+'{:0.2f}'.format(iouTrain*100) + '\033[0m' 271 | print ("EPOCH IoU on TRAIN set: ", iouStr, "%") 272 | 273 | #Validate on 500 val images after each epoch of training 274 | print("----- VALIDATING - EPOCH", epoch, "-----") 275 | model.eval() 276 | epoch_loss_val = [] 277 | time_val = [] 278 | 279 | if (doIouVal): 280 | iouEvalVal = iouEval(NUM_CLASSES) 281 | 282 | for step, (images, labels) in enumerate(loader_val): 283 | start_time = time.time() 284 | 285 | imgs_batch = images.shape[0] 286 | if imgs_batch != args.batch_size: 287 | break 288 | 289 | if args.cuda: 290 | images = images.cuda() 291 | labels = labels.cuda() 292 | 293 | with torch.no_grad(): 294 | inputs = Variable(images) 295 | targets = Variable(labels) 296 | 297 | outputs = model(inputs, only_encode=enc) 298 | 299 | loss = criterion(outputs, targets[:, 0]) 300 | epoch_loss_val.append(loss.item()) 301 | time_val.append(time.time() - start_time) 302 | 303 | 304 | #Add batch to calculate TP, FP and FN for iou estimation 305 | if (doIouVal): 306 | #start_time_iou = time.time() 307 | iouEvalVal.addBatch(outputs.max(1)[1].unsqueeze(1).data, targets.data) 308 | #print ("Time to add confusion matrix: ", time.time() - start_time_iou) 309 | 310 | if args.visualize and args.steps_plot > 0 and step % args.steps_plot == 0: 311 | start_time_plot = time.time() 312 | image = inputs[0].cpu().data 313 | board.image(image, f'VAL input (epoch: {epoch}, step: {step})') 314 | if isinstance(outputs, list): #merge gpu tensors 315 | board.image(color_transform(outputs[0][0].cpu().max(0)[1].data.unsqueeze(0)), 316 | f'VAL output (epoch: {epoch}, step: {step})') 317 | else: 318 | board.image(color_transform(outputs[0].cpu().max(0)[1].data.unsqueeze(0)), 319 | f'VAL output (epoch: {epoch}, step: {step})') 320 | board.image(color_transform(targets[0].cpu().data), 321 | f'VAL target (epoch: {epoch}, step: {step})') 322 | print ("Time to paint images: ", time.time() - start_time_plot) 323 | if 
args.steps_loss > 0 and step % args.steps_loss == 0: 324 | average = sum(epoch_loss_val) / len(epoch_loss_val) 325 | print(f'VAL loss: {average:0.4} (epoch: {epoch}, step: {step})', 326 | "// Avg time/img: %.4f s" % (sum(time_val) / len(time_val) / args.batch_size)) 327 | 328 | 329 | average_epoch_loss_val = sum(epoch_loss_val) / len(epoch_loss_val) 330 | #scheduler.step(average_epoch_loss_val, epoch) ## scheduler 1 # update lr if needed 331 | 332 | iouVal = 0 333 | if (doIouVal): 334 | iouVal, iou_classes = iouEvalVal.getIoU() 335 | iouStr = getColorEntry(iouVal)+'{:0.2f}'.format(iouVal*100) + '\033[0m' 336 | print ("EPOCH IoU on VAL set: ", iouStr, "%") 337 | 338 | 339 | # remember best valIoU and save checkpoint 340 | if iouVal == 0: 341 | current_acc = -average_epoch_loss_val 342 | else: 343 | current_acc = iouVal 344 | is_best = current_acc > best_acc 345 | best_acc = max(current_acc, best_acc) 346 | if enc: 347 | filenameCheckpoint = savedir + '/checkpoint_enc.pth.tar' 348 | filenameBest = savedir + '/model_best_enc.pth.tar' 349 | else: 350 | filenameCheckpoint = savedir + '/checkpoint.pth.tar' 351 | filenameBest = savedir + '/model_best.pth.tar' 352 | save_checkpoint({ 353 | 'epoch': epoch + 1, 354 | 'arch': str(model), 355 | 'state_dict': model.state_dict(), 356 | 'best_acc': best_acc, 357 | 'optimizer' : optimizer.state_dict(), 358 | }, is_best, filenameCheckpoint, filenameBest) 359 | 360 | #SAVE MODEL AFTER EPOCH 361 | if (enc): 362 | filename = f'{savedir}/model_encoder-{epoch:03}.pth' 363 | filenamebest = f'{savedir}/model_encoder_best.pth' 364 | else: 365 | filename = f'{savedir}/model-{epoch:03}.pth' 366 | filenamebest = f'{savedir}/model_best.pth' 367 | if args.epochs_save > 0 and step > 0 and step % args.epochs_save == 0: 368 | torch.save(model.state_dict(), filename) 369 | print(f'save: {filename} (epoch: {epoch})') 370 | if (is_best): 371 | torch.save(model.state_dict(), filenamebest) 372 | print(f'save: {filenamebest} (epoch: {epoch})') 373 | if (not enc): 374 | with open(savedir + "/best.txt", "w") as myfile: 375 | myfile.write("Best epoch is %d, with Val-IoU= %.4f" % (epoch, iouVal)) 376 | else: 377 | with open(savedir + "/best_encoder.txt", "w") as myfile: 378 | myfile.write("Best epoch is %d, with Val-IoU= %.4f" % (epoch, iouVal)) 379 | 380 | #SAVE TO FILE A ROW WITH THE EPOCH RESULT (train loss, val loss, train IoU, val IoU) 381 | #Epoch Train-loss Test-loss Train-IoU Test-IoU learningRate 382 | with open(automated_log_path, "a") as myfile: 383 | myfile.write("\n%d\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.8f" % (epoch, average_epoch_loss_train, average_epoch_loss_val, iouTrain, iouVal, usedLr )) 384 | 385 | return(model) #return model (convenience for encoder-decoder training) 386 | 387 | def save_checkpoint(state, is_best, filenameCheckpoint, filenameBest): 388 | torch.save(state, filenameCheckpoint) 389 | if is_best: 390 | print ("Saving model as best") 391 | torch.save(state, filenameBest) 392 | 393 | 394 | def main(args): 395 | savedir = f'../save/{args.savedir}' 396 | 397 | if not os.path.exists(savedir): 398 | os.makedirs(savedir) 399 | 400 | with open(savedir + '/opts.txt', "w") as myfile: 401 | myfile.write(str(args)) 402 | 403 | #Load Model 404 | assert os.path.exists(args.model + ".py"), "Error: model definition not found" 405 | model_file = importlib.import_module(args.model) 406 | model = model_file.Net(NUM_CLASSES) 407 | copyfile(args.model + ".py", savedir + '/' + args.model + ".py") 408 | 409 | if args.cuda: 410 | model = 
torch.nn.DataParallel(model).cuda() 411 | 412 | if args.state: 413 | #if args.state is provided then load this state for training 414 | #Note: this only loads initialized weights. If you want to resume a training use "--resume" option!! 415 | """ 416 | try: 417 | model.load_state_dict(torch.load(args.state)) 418 | except AssertionError: 419 | model.load_state_dict(torch.load(args.state, 420 | map_location=lambda storage, loc: storage)) 421 | #When model is saved as DataParallel it adds a model. to each key. To remove: 422 | #state_dict = {k.partition('model.')[2]: v for k,v in state_dict} 423 | #https://discuss.pytorch.org/t/prefix-parameter-names-in-saved-model-if-trained-by-multi-gpu/494 424 | """ 425 | def load_my_state_dict(model, state_dict): #custom function to load model when not all dict keys are there 426 | own_state = model.state_dict() 427 | for name, param in state_dict.items(): 428 | if name not in own_state: 429 | continue 430 | own_state[name].copy_(param) 431 | return model 432 | 433 | #print(torch.load(args.state)) 434 | model = load_my_state_dict(model, torch.load(args.state)) 435 | 436 | """ 437 | def weights_init(m): 438 | classname = m.__class__.__name__ 439 | if classname.find('Conv') != -1: 440 | #m.weight.data.normal_(0.0, 0.02) 441 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 442 | m.weight.data.normal_(0, math.sqrt(2. / n)) 443 | elif classname.find('BatchNorm') != -1: 444 | #m.weight.data.normal_(1.0, 0.02) 445 | m.weight.data.fill_(1) 446 | m.bias.data.fill_(0) 447 | 448 | #TO ACCESS MODEL IN DataParallel: next(model.children()) 449 | #next(model.children()).decoder.apply(weights_init) 450 | #Reinitialize weights for decoder 451 | 452 | next(model.children()).decoder.layers.apply(weights_init) 453 | next(model.children()).decoder.output_conv.apply(weights_init) 454 | 455 | #print(model.state_dict()) 456 | f = open('weights5.txt', 'w') 457 | f.write(str(model.state_dict())) 458 | f.close() 459 | """ 460 | 461 | #train(args, model) 462 | if (not args.decoder): 463 | print("========== ENCODER TRAINING ===========") 464 | model = train(args, model, True) #Train encoder 465 | #CAREFUL: for some reason, after training encoder alone, the decoder gets weights=0. 
466 | #We must reinit decoder weights or reload network passing only encoder in order to train decoder 467 | print("========== DECODER TRAINING ===========") 468 | if (not args.state): 469 | if args.pretrainedEncoder: 470 | print("Loading encoder pretrained in imagenet") 471 | from lednet_imagenet import LEDNet as LEDNet_imagenet 472 | pretrainedEnc = torch.nn.DataParallel(LEDNet_imagenet(1000)) 473 | pretrainedEnc.load_state_dict(torch.load(args.pretrainedEncoder)['state_dict']) 474 | pretrainedEnc = next(pretrainedEnc.children()).features.encoder 475 | if (not args.cuda): 476 | pretrainedEnc = pretrainedEnc.cpu() #because loaded encoder is probably saved in cuda 477 | else: 478 | pretrainedEnc = next(model.children()).encoder 479 | model = model_file.Net(NUM_CLASSES, encoder=pretrainedEnc) #Add decoder to encoder 480 | if args.cuda: 481 | model = torch.nn.DataParallel(model).cuda() 482 | #When loading encoder reinitialize weights for decoder because they are set to 0 when training dec 483 | model = train(args, model, False) #Train decoder 484 | print("========== TRAINING FINISHED ===========") 485 | 486 | if __name__ == '__main__': 487 | parser = ArgumentParser() 488 | parser.add_argument('--cuda', action='store_true', default=True) #NOTE: cpu-only has not been tested so you might have to change code if you deactivate this flag 489 | parser.add_argument('--model', default= "lednet") 490 | parser.add_argument('--state') 491 | 492 | parser.add_argument('--port', type=int, default=8097) 493 | parser.add_argument('--datadir', default=os.getenv("HOME") + "/datasets/cityscapes/") 494 | parser.add_argument('--height', type=int, default=512) 495 | parser.add_argument('--num-epochs', type=int, default=300) 496 | parser.add_argument('--num-workers', type=int, default=4) 497 | parser.add_argument('--batch-size', type=int, default=5) 498 | parser.add_argument('--steps-loss', type=int, default=50) 499 | parser.add_argument('--steps-plot', type=int, default=50) 500 | parser.add_argument('--epochs-save', type=int, default=0) #You can use this value to save model every X epochs 501 | parser.add_argument('--savedir', required=True) 502 | parser.add_argument('--decoder', action='store_true') 503 | parser.add_argument('--pretrainedEncoder') #, default=" ") 504 | parser.add_argument('--visualize', action='store_true') 505 | 506 | parser.add_argument('--iouTrain', action='store_true', default=True) #recommended: False (takes more time to train otherwise) 507 | parser.add_argument('--iouVal', action='store_true', default=True) 508 | parser.add_argument('--resume', action='store_true') #Use this flag to load last checkpoint for training 509 | 510 | main(parser.parse_args()) 511 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiaoyufenfei/LEDNet/5d900d9cfabb3091c952b79be34246aea0608e42/utils/__init__.py -------------------------------------------------------------------------------- /utils/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | 4 | from PIL import Image 5 | 6 | from torch.utils.data import Dataset 7 | 8 | EXTENSIONS = ['.jpg', '.png'] 9 | 10 | def load_image(file): 11 | return Image.open(file) 12 | 13 | def is_image(filename): 14 | return any(filename.endswith(ext) for ext in EXTENSIONS) 15 | 16 | def is_label(filename): 17 | return 
filename.endswith("_labelTrainIds.png") 18 | 19 | def image_path(root, basename, extension): 20 | return os.path.join(root, f'{basename}{extension}') 21 | 22 | def image_path_city(root, name): 23 | return os.path.join(root, f'{name}') 24 | 25 | def image_basename(filename): 26 | return os.path.basename(os.path.splitext(filename)[0]) 27 | 28 | class VOC12(Dataset): 29 | 30 | def __init__(self, root, input_transform=None, target_transform=None): 31 | self.images_root = os.path.join(root, 'images') 32 | self.labels_root = os.path.join(root, 'labels') 33 | 34 | self.filenames = [image_basename(f) 35 | for f in os.listdir(self.labels_root) if is_image(f)] 36 | self.filenames.sort() 37 | 38 | self.input_transform = input_transform 39 | self.target_transform = target_transform 40 | 41 | def __getitem__(self, index): 42 | filename = self.filenames[index] 43 | 44 | with open(image_path(self.images_root, filename, '.jpg'), 'rb') as f: 45 | image = load_image(f).convert('RGB') 46 | with open(image_path(self.labels_root, filename, '.png'), 'rb') as f: 47 | label = load_image(f).convert('P') 48 | 49 | if self.input_transform is not None: 50 | image = self.input_transform(image) 51 | if self.target_transform is not None: 52 | label = self.target_transform(label) 53 | 54 | return image, label 55 | 56 | def __len__(self): 57 | return len(self.filenames) 58 | 59 | 60 | 61 | 62 | class cityscapes(Dataset): 63 | 64 | def __init__(self, root, co_transform=None, subset='train'): 65 | self.images_root = os.path.join(root, 'leftImg8bit/') 66 | self.labels_root = os.path.join(root, 'gtFine/') 67 | #self.labels_root = os.path.join(root, 'gtCoarse/') 68 | self.images_root += subset 69 | self.labels_root += subset 70 | 71 | print (self.images_root) 72 | #self.filenames = [image_basename(f) for f in os.listdir(self.images_root) if is_image(f)] 73 | self.filenames = [os.path.join(dp, f) for dp, dn, fn in os.walk(os.path.expanduser(self.images_root)) for f in fn if is_image(f)] 74 | self.filenames.sort() 75 | 76 | #[os.path.join(dp, f) for dp, dn, fn in os.walk(os.path.expanduser(".")) for f in fn] 77 | #self.filenamesGt = [image_basename(f) for f in os.listdir(self.labels_root) if is_image(f)] 78 | self.filenamesGt = [os.path.join(dp, f) for dp, dn, fn in os.walk(os.path.expanduser(self.labels_root)) for f in fn if is_label(f)] 79 | self.filenamesGt.sort() 80 | 81 | self.co_transform = co_transform # ADDED THIS 82 | 83 | 84 | def __getitem__(self, index): 85 | filename = self.filenames[index] 86 | filenameGt = self.filenamesGt[index] 87 | 88 | with open(image_path_city(self.images_root, filename), 'rb') as f: 89 | image = load_image(f).convert('RGB') 90 | with open(image_path_city(self.labels_root, filenameGt), 'rb') as f: 91 | label = load_image(f).convert('P') 92 | 93 | if self.co_transform is not None: 94 | image, label = self.co_transform(image, label) 95 | 96 | return image, label 97 | 98 | def __len__(self): 99 | return len(self.filenames) 100 | 101 | -------------------------------------------------------------------------------- /utils/iouEval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class iouEval: 4 | 5 | def __init__(self, nClasses, ignoreIndex=19): 6 | self.nClasses = nClasses 7 | self.ignoreIndex = ignoreIndex if nClasses>ignoreIndex else -1 #if ignoreIndex is larger than nClasses, consider no ignoreIndex 8 | self.reset() 9 | 10 | def reset (self): 11 | classes = self.nClasses if self.ignoreIndex==-1 else self.nClasses-1 12 | self.tp = 
torch.zeros(classes).double() 13 | self.fp = torch.zeros(classes).double() 14 | self.fn = torch.zeros(classes).double() 15 | 16 | def addBatch(self, x, y): #x=preds, y=targets 17 | #sizes should be "batch_size x nClasses x H x W" 18 | 19 | #print ("X is cuda: ", x.is_cuda) 20 | #print ("Y is cuda: ", y.is_cuda) 21 | 22 | if (x.is_cuda or y.is_cuda): 23 | x = x.cuda() 24 | y = y.cuda() 25 | 26 | #if size is "batch_size x 1 x H x W" scatter to onehot 27 | if (x.size(1) == 1): 28 | x_onehot = torch.zeros(x.size(0), self.nClasses, x.size(2), x.size(3)) 29 | if x.is_cuda: 30 | x_onehot = x_onehot.cuda() 31 | x_onehot.scatter_(1, x, 1).float() 32 | else: 33 | x_onehot = x.float() 34 | 35 | if (y.size(1) == 1): 36 | y_onehot = torch.zeros(y.size(0), self.nClasses, y.size(2), y.size(3)) 37 | if y.is_cuda: 38 | y_onehot = y_onehot.cuda() 39 | y_onehot.scatter_(1, y, 1).float() 40 | else: 41 | y_onehot = y.float() 42 | 43 | if (self.ignoreIndex != -1): 44 | ignores = y_onehot[:,self.ignoreIndex].unsqueeze(1) 45 | x_onehot = x_onehot[:, :self.ignoreIndex] 46 | y_onehot = y_onehot[:, :self.ignoreIndex] 47 | else: 48 | ignores=0 49 | 50 | #print(type(x_onehot)) 51 | #print(type(y_onehot)) 52 | #print(x_onehot.size()) 53 | #print(y_onehot.size()) 54 | 55 | tpmult = x_onehot * y_onehot #times prediction and gt coincide is 1 56 | tp = torch.sum(torch.sum(torch.sum(tpmult, dim=0, keepdim=True), dim=2, keepdim=True), dim=3, keepdim=True).squeeze() 57 | fpmult = x_onehot * (1-y_onehot-ignores) #times prediction says its that class and gt says its not (subtracting cases when its ignore label!) 58 | fp = torch.sum(torch.sum(torch.sum(fpmult, dim=0, keepdim=True), dim=2, keepdim=True), dim=3, keepdim=True).squeeze() 59 | fnmult = (1-x_onehot) * (y_onehot) #times prediction says its not that class and gt says it is 60 | fn = torch.sum(torch.sum(torch.sum(fnmult, dim=0, keepdim=True), dim=2, keepdim=True), dim=3, keepdim=True).squeeze() 61 | 62 | self.tp += tp.double().cpu() 63 | self.fp += fp.double().cpu() 64 | self.fn += fn.double().cpu() 65 | 66 | def getIoU(self): 67 | num = self.tp 68 | den = self.tp + self.fp + self.fn + 1e-15 69 | iou = num / den 70 | return torch.mean(iou), iou #returns "iou mean", "iou per class" 71 | 72 | # Class for colors 73 | class colors: 74 | RED = '\033[31;1m' 75 | GREEN = '\033[32;1m' 76 | YELLOW = '\033[33;1m' 77 | BLUE = '\033[34;1m' 78 | MAGENTA = '\033[35;1m' 79 | CYAN = '\033[36;1m' 80 | BOLD = '\033[1m' 81 | UNDERLINE = '\033[4m' 82 | ENDC = '\033[0m' 83 | 84 | # Colored value output if colorized flag is activated. 
85 | def getColorEntry(val): 86 | if not isinstance(val, float): 87 | return colors.ENDC 88 | if (val < .20): 89 | return colors.RED 90 | elif (val < .40): 91 | return colors.YELLOW 92 | elif (val < .60): 93 | return colors.BLUE 94 | elif (val < .80): 95 | return colors.CYAN 96 | else: 97 | return colors.GREEN 98 | 99 | -------------------------------------------------------------------------------- /utils/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class CrossEntropyLoss2d(torch.nn.Module): 7 | 8 | def __init__(self, weight=None): 9 | super(CrossEntropyLoss2d,self).__init__() 10 | 11 | self.loss = nn.NLLLoss(weight) 12 | 13 | def forward(self, outputs, targets): 14 | return self.loss(F.log_softmax(outputs, dim=1), targets) 15 | 16 | 17 | 18 | 19 | 20 | 21 | -------------------------------------------------------------------------------- /utils/transform.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from PIL import Image 5 | 6 | def colormap_cityscapes(n): 7 | cmap=np.zeros([n, 3]).astype(np.uint8) 8 | cmap[0,:] = np.array([128, 64,128]) 9 | cmap[1,:] = np.array([244, 35,232]) 10 | cmap[2,:] = np.array([ 70, 70, 70]) 11 | cmap[3,:] = np.array([ 102,102,156]) 12 | cmap[4,:] = np.array([ 190,153,153]) 13 | cmap[5,:] = np.array([ 153,153,153]) 14 | 15 | cmap[6,:] = np.array([ 250,170, 30]) 16 | cmap[7,:] = np.array([ 220,220, 0]) 17 | cmap[8,:] = np.array([ 107,142, 35]) 18 | cmap[9,:] = np.array([ 152,251,152]) 19 | cmap[10,:] = np.array([ 70,130,180]) 20 | 21 | cmap[11,:] = np.array([ 220, 20, 60]) 22 | cmap[12,:] = np.array([ 255, 0, 0]) 23 | cmap[13,:] = np.array([ 0, 0,142]) 24 | cmap[14,:] = np.array([ 0, 0, 70]) 25 | cmap[15,:] = np.array([ 0, 60,100]) 26 | 27 | cmap[16,:] = np.array([ 0, 80,100]) 28 | cmap[17,:] = np.array([ 0, 0,230]) 29 | cmap[18,:] = np.array([ 119, 11, 32]) 30 | cmap[19,:] = np.array([ 0, 0, 0]) 31 | 32 | return cmap 33 | 34 | 35 | def colormap(n): 36 | cmap=np.zeros([n, 3]).astype(np.uint8) 37 | 38 | for i in np.arange(n): 39 | r, g, b = np.zeros(3) 40 | 41 | for j in np.arange(8): 42 | r = r + (1<<(7-j))*((i&(1<<(3*j))) >> (3*j)) 43 | g = g + (1<<(7-j))*((i&(1<<(3*j+1))) >> (3*j+1)) 44 | b = b + (1<<(7-j))*((i&(1<<(3*j+2))) >> (3*j+2)) 45 | 46 | cmap[i,:] = np.array([r, g, b]) 47 | 48 | return cmap 49 | 50 | class Relabel: 51 | 52 | def __init__(self, olabel, nlabel): 53 | self.olabel = olabel 54 | self.nlabel = nlabel 55 | 56 | def __call__(self, tensor): 57 | assert (isinstance(tensor, torch.LongTensor) or isinstance(tensor, torch.ByteTensor)) , 'tensor needs to be LongTensor' 58 | tensor[tensor == self.olabel] = self.nlabel 59 | return tensor 60 | 61 | 62 | class ToLabel: 63 | 64 | def __call__(self, image): 65 | return torch.from_numpy(np.array(image)).long().unsqueeze(0) 66 | 67 | 68 | class Colorize: 69 | 70 | def __init__(self, n=22): 71 | #self.cmap = colormap(256) 72 | self.cmap = colormap_cityscapes(256) 73 | self.cmap[n] = self.cmap[-1] 74 | self.cmap = torch.from_numpy(self.cmap[:n]) 75 | 76 | def __call__(self, gray_image): 77 | size = gray_image.size() 78 | #print(size) 79 | color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0) 80 | #color_image = torch.ByteTensor(3, size[0], size[1]).fill_(0) 81 | 82 | #for label in range(1, len(self.cmap)): 83 | for label in range(0, len(self.cmap)): 84 | mask = gray_image[0] == label 
85 | #mask = gray_image == label 86 | 87 | color_image[0][mask] = self.cmap[label][0] 88 | color_image[1][mask] = self.cmap[label][1] 89 | color_image[2][mask] = self.cmap[label][2] 90 | 91 | return color_image 92 | -------------------------------------------------------------------------------- /utils/visualize.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from torch.autograd import Variable 4 | 5 | from visdom import Visdom 6 | 7 | class Dashboard: 8 | 9 | def __init__(self, port): 10 | self.vis = Visdom(port=port) 11 | 12 | def loss(self, losses, title): 13 | x = np.arange(1, len(losses)+1, 1) 14 | 15 | self.vis.line(losses, x, env='loss', opts=dict(title=title)) 16 | 17 | def image(self, image, title): 18 | if image.is_cuda: 19 | image = image.cpu() 20 | if isinstance(image, Variable): 21 | image = image.data 22 | image = image.numpy() 23 | 24 | self.vis.image(image, env='images', opts=dict(title=title)) --------------------------------------------------------------------------------
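
The snippet below sketches how the pieces under `utils/` fit together at train/evaluation time: `cityscapes` feeds jointly transformed image/label pairs to a `DataLoader`, `CrossEntropyLoss2d` scores the logits, and `iouEval` accumulates per-class TP/FP/FN counts. This is a minimal sketch rather than code from the repository; the dataset path, the 512-pixel resize, the `JointTransform` helper, and the commented-out `Net`/Visdom lines are illustrative assumptions.

```python
# Sketch only (not part of the repository): one plausible way to combine the
# utilities above. JointTransform, the paths and the batch size are assumptions.
from PIL import Image
from torch.utils.data import DataLoader
from torchvision.transforms import Resize, ToTensor

from utils.dataset import cityscapes
from utils.transform import Relabel, ToLabel
from utils.loss import CrossEntropyLoss2d
from utils.iouEval import iouEval

NUM_CLASSES = 20  # 19 Cityscapes training classes + 1 ignore index (19)


class JointTransform(object):
    """Joint (image, label) transform for the dataset's co_transform hook.

    Hypothetical helper; the training script builds its own equivalent.
    """

    def __init__(self, height=512):
        self.height = height

    def __call__(self, image, label):
        image = Resize(self.height, Image.BILINEAR)(image)
        label = Resize(self.height, Image.NEAREST)(label)
        image = ToTensor()(image)          # float tensor, (3, H, W)
        label = ToLabel()(label)           # long tensor, (1, H, W)
        label = Relabel(255, 19)(label)    # Cityscapes 'ignore' pixels -> class 19
        return image, label


# Assumed layout: datasets/cityscapes/{leftImg8bit,gtFine}/{train,val,...},
# with labels already converted to *_labelTrainIds.png via cityscapesscripts.
dataset = cityscapes('datasets/cityscapes/', co_transform=JointTransform(512), subset='val')
loader = DataLoader(dataset, batch_size=2, shuffle=False, num_workers=2)

criterion = CrossEntropyLoss2d()      # NLLLoss over log-softmax, as in utils/loss.py
evaluator = iouEval(NUM_CLASSES)      # ignoreIndex defaults to 19

# A forward/evaluation pass would then look roughly like this (model omitted):
# net = Net(NUM_CLASSES)                        # Net from train/lednet.py
# for images, labels in loader:
#     outputs = net(images)                     # (N, 20, H, W) logits
#     loss = criterion(outputs, labels[:, 0])   # targets must be (N, H, W)
#     preds = outputs.max(1)[1].unsqueeze(1)    # (N, 1, H, W) predicted class ids
#     evaluator.addBatch(preds.data, labels.data)
# mean_iou, per_class_iou = evaluator.getIoU()
#
# To inspect a prediction visually, Colorize and Dashboard can be chained,
# assuming a Visdom server is running (`python -m visdom.server`):
# from utils.transform import Colorize
# from utils.visualize import Dashboard
# board = Dashboard(port=8097)
# board.image(Colorize(NUM_CLASSES)(preds[0].cpu()), 'prediction')
```

For full training, `train/main.py` drives these same utilities through its two-phase encoder/decoder schedule and the options listed in its argument parser (e.g. `--datadir`, `--height`, `--num-epochs`, `--batch-size`, `--savedir`).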