├── .gitignore
├── LICENSE
├── README.md
└── src
    ├── data
    │   ├── __init__.py
    │   ├── color2class.py
    │   └── data_loader.py
    ├── model
    │   ├── __init__.py
    │   ├── enc_net.py
    │   ├── resnet18.py
    │   └── utils.py
    └── train.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

# own
*.tags
*.tags1
/data
/results

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 shimao

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# pytorch-EncNet

A PyTorch implementation of an EncNet-style semantic segmentation model,
trained on the CamVid (CamSeq01) dataset. `src/data/color2class.py` converts
the color-coded label images into class-index maps, and `src/train.py` trains
and evaluates the model.
--------------------------------------------------------------------------------
/src/data/__init__.py:
--------------------------------------------------------------------------------
from . import data_loader
--------------------------------------------------------------------------------
/src/data/color2class.py:
--------------------------------------------------------------------------------
import os
from pathlib import Path
import pandas as pd
import numpy as np
from skimage import io
from matplotlib import pyplot as plt


data_root_dir = './data'
label_table_name = 'label_color'
image_dir = 'CamSeq01'
image_dir_path = Path(data_root_dir, image_dir)

label_table = pd.read_csv(Path(data_root_dir, label_table_name))
label_image_file_names = \
    [f for f in os.listdir(image_dir_path) if '_L.png' in f]

for image_file_name in label_image_file_names:
    file_path = Path(image_dir_path, image_file_name)
    print('processing {}'.format(file_path))
    img = io.imread(file_path)
    # plt.figure()
    # io.imshow(img)
    # plt.show()

    # Every pixel is assumed to match one of the table colors; an unmatched
    # pixel would keep whatever value np.empty happened to leave behind.
    class_map = np.empty((img.shape[0], img.shape[1]), dtype='i')
    for i_row in range(len(label_table)):
        class_color = label_table.iloc[i_row]
        # access the color channels by column name, matching train.py
        idx = np.where(
            (img[:, :, 0] == class_color['r'])
            & (img[:, :, 1] == class_color['g'])
            & (img[:, :, 2] == class_color['b']))
        class_map[idx] = i_row
    label_file_name = Path(image_dir_path,
                           os.path.splitext(image_file_name)[0] + '.npz')
    np.savez(label_file_name, data=class_map)
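
# The 'label_color' table is assumed to be a header-carrying CSV with one row
# per class and (at least) columns r, g, b, matching the named access above
# and in train.py. A hypothetical excerpt for CamVid's 32 classes:
#
#   name,r,g,b
#   Animal,64,128,64
#   Archway,192,0,128
#   ...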
--------------------------------------------------------------------------------
/src/data/data_loader.py:
--------------------------------------------------------------------------------
from pathlib import Path
import numpy as np
from skimage import io
import torch
from torch.utils.data import Dataset


class CamVidDataset(Dataset):
    def __init__(self, image_file_names, root_dir,
                 subset=False, transform=None):
        super().__init__()
        self.image_file_names = image_file_names
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.image_file_names)

    def __getitem__(self, idx):
        img_name = self.image_file_names[idx]

        image_path = Path(self.root_dir, img_name)
        label_path = Path(self.root_dir, img_name.replace('.png', '_L.npz'))

        image = io.imread(image_path)
        image = image.transpose(2, 0, 1)  # HWC -> CHW
        label = np.load(label_path)['data']

        # downsample by 4 in both spatial dimensions (720x960 -> 180x240)
        image = image[:, ::4, ::4]
        return torch.FloatTensor(image), torch.LongTensor(label[::4, ::4])


def loader(dataset, batch_size, shuffle=True):
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=4)

# def set_data_loader(args, train=True):
#     dataset = load_data(args)
#     train_data, test_data = train_test_split(dataset, test_size=.2)
#
#     assert train_data
#     assert test_data
#
#     if train:
#         dataset = train_data
#     else:
#         dataset = test_data
#
#     data_loader = torch.utils.data.DataLoader(
#         dataset, batch_size=args.batch_size,
#         shuffle=True, num_workers=int(args.workers))
#     return data_loader
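
# A minimal smoke test, assuming the CamSeq01 layout and the .npz label maps
# produced by color2class.py; the frame name below is hypothetical.
if __name__ == '__main__':
    dataset = CamVidDataset(['0016E5_07959.png'], './data/CamSeq01')
    images, labels = next(iter(loader(dataset, batch_size=1)))
    # expect torch.Size([1, 3, 180, 240]) and torch.Size([1, 180, 240])
    print(images.shape, labels.shape)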
--------------------------------------------------------------------------------
/src/model/__init__.py:
--------------------------------------------------------------------------------
from . import resnet18
from . import utils
--------------------------------------------------------------------------------
/src/model/enc_net.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo


model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
}


def conv3x3(in_planes, out_planes, stride=1, dilation=1):
    """3x3 convolution with padding"""
    # padding must grow with dilation to keep the spatial size unchanged
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, bias=False, dilation=dilation)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 dilation=1):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride, dilation)
        self.bn1 = nn.BatchNorm2d(planes)
        self.selu = nn.SELU(inplace=True)
        self.conv2 = conv3x3(planes, planes, dilation=dilation)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        # out = self.bn1(out)
        out = self.selu(out)

        out = self.conv2(out)
        # out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.selu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 dilation=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=dilation, dilation=dilation,
                               bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion,
                               kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.selu = nn.SELU(inplace=True)  # forward() expects self.selu
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        # out = self.bn1(out)
        out = self.selu(out)

        out = self.conv2(out)
        # out = self.bn2(out)
        out = self.selu(out)

        out = self.conv3(out)
        # out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.selu(out)

        return out


class ELayer(nn.Module):
    """Auxiliary head: global average pool + linear + per-class sigmoid."""

    def __init__(self, fc_input, pool_kernel, n_classes):
        super().__init__()
        # pool_kernel is accepted for signature compatibility but unused
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(fc_input, n_classes)

    def forward(self, x):
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        return torch.sigmoid(self.fc(x))


class Encoder(nn.Module):
    def __init__(self, block, layers, num_classes=32):
        self.inplanes = 64
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.selu = nn.SELU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], dilation=2)
        self.layer4 = self._make_layer(
            block, num_classes, layers[3], dilation=4)

        self.es1 = ELayer(256, 7, num_classes)

        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(num_classes * block.expansion, 256)
        self.fc_out = nn.Linear(256, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False,
                          dilation=dilation),
                # nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        # forward the dilation factor to every block in the stage
        layers.append(block(self.inplanes, planes, stride, downsample,
                            dilation=dilation))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, dilation=dilation))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.selu(self.bn1(self.conv1(x)))
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        self.se2 = self.es1(x)  # class-presence scores from layer3 features

        x = self.layer4(x)
        self.feature_map = x  # stashed for channel scaling in Net.forward

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)

        x = self.fc_out(x)
        self.se1 = torch.sigmoid(x)

        return x


class Decoder(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        # self.cnv1 = nn.ConvTranspose2d(num_classes, num_classes, 3, 2, padding=1)
        # self.cnv2 = nn.ConvTranspose2d(num_classes, num_classes, 3, 2, padding=1)
        # self.cnv3 = nn.ConvTranspose2d(num_classes, num_classes, 3, 2)

        self.conv1_1 = nn.Conv2d(
            num_classes, num_classes*2, kernel_size=5, padding=2)
        self.conv1_2 = nn.Conv2d(
            num_classes*2, num_classes*4, kernel_size=5, padding=2)
        self.conv2_1 = nn.Conv2d(
            num_classes, num_classes*2, kernel_size=5, padding=2)
        self.conv2_2 = nn.Conv2d(
            num_classes*2, num_classes*4, kernel_size=5, padding=2)
        self.conv3_1 = nn.Conv2d(
            num_classes, num_classes*2, kernel_size=5, padding=2)
        # padding=(1, 2) trims the height (e.g. 92 -> 90) so the final pixel
        # shuffle lands exactly on the 180x240 input resolution
        self.conv3_2 = nn.Conv2d(
            num_classes*2, num_classes*4, kernel_size=5, padding=(1, 2))

        self.ps1 = nn.PixelShuffle(2)
        self.ps2 = nn.PixelShuffle(2)
        self.ps3 = nn.PixelShuffle(2)

        # unused, kept for checkpoint compatibility
        self.bn1 = nn.BatchNorm2d(num_classes)
        self.bn2 = nn.BatchNorm2d(num_classes)

    def forward(self, input):
        h = F.selu(self.conv1_1(input))
        h = self.ps1(F.selu(self.conv1_2(h)))
        # h = F.selu(self.cnv1(input))

        h = F.selu(self.conv2_1(h))
        h = self.ps2(F.selu(self.conv2_2(h)))
        # h = F.selu(self.cnv2(h))

        h = F.selu(self.conv3_1(h))
        h = self.ps3(self.conv3_2(h))  # no activation: raw class logits
        return h


class Net(nn.Module):
    def __init__(self, num_classes, **kwargs):
        super().__init__()
        self.encoder = Encoder(BasicBlock, [2, 2, 2, 2], **kwargs)
        self.decoder = Decoder(num_classes)

    def forward(self, input):
        h = self.encoder(input)
        # scale each channel of the encoder feature map by its class score
        # (broadcasting over the spatial dimensions)
        decoder_in = self.encoder.feature_map \
            * h.unsqueeze(-1).unsqueeze(-1)
        out = self.decoder(decoder_in)
        return out, self.encoder.se2, self.encoder.se1


def enc_net(num_classes, pretrained=False, **kwargs):
    """Constructs an EncNet-style segmentation model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = Net(num_classes, **kwargs)
    # if pretrained:
    #     model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
    return model
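
# A minimal shape check, assuming 180x240 inputs as produced by
# CamVidDataset (CamSeq01 frames downsampled by 4).
if __name__ == '__main__':
    model = enc_net(num_classes=32)
    out, se2, se1 = model(torch.randn(1, 3, 180, 240))
    # expect torch.Size([1, 32, 180, 240]) and two [1, 32] score vectors
    print(out.shape, se2.shape, se1.shape)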
--------------------------------------------------------------------------------
/src/model/resnet18.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.utils.model_zoo as model_zoo


model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
}


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion,
                               kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out',
                                        nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
    return model
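
# A quick sanity check with the standard ImageNet input size; pass
# pretrained=True to pull the torchvision weights instead.
if __name__ == '__main__':
    import torch
    model = resnet18()
    print(model(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 1000])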
--------------------------------------------------------------------------------
/src/model/utils.py:
--------------------------------------------------------------------------------
import torch
from torch import nn


def calculate_l1_loss(output, target, lagrange_coef=0.0005):
    # Note: argmax is not differentiable, so this term contributes no
    # gradient; the call to it in train.py is commented out.
    l1_crit = nn.L1Loss(reduction='sum')  # SmoothL1Loss
    reg_loss = l1_crit(output.argmax(dim=1).float(), target.float())

    return lagrange_coef * reg_loss


def smooth_in(model):
    """Perturb every parameter with uniform noise and return the noise."""
    l_noise = []
    for p in model.parameters():
        noise = torch.FloatTensor(p.shape).uniform_(-.01, .01)
        p.data -= noise
        l_noise.append(noise)
    return l_noise


def smooth_out(model, l_noise):
    """Undo a previous smooth_in by adding the recorded noise back."""
    for p, noise in zip(model.parameters(), l_noise):
        p.data += noise
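
# A tiny round-trip check, assuming these helpers are meant to bracket a
# backward pass (perturb, compute gradients, restore):
if __name__ == '__main__':
    net = nn.Linear(4, 2)
    before = net.weight.detach().clone()
    noise = smooth_in(net)
    smooth_out(net, noise)
    print(torch.allclose(net.weight, before))  # True: parameters restored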
--------------------------------------------------------------------------------
/src/train.py:
--------------------------------------------------------------------------------
import argparse
from pathlib import Path
import numpy as np
import pandas as pd
from skimage import io
import matplotlib.pylab as plt
import torch
import torch.optim as optim
import torch.nn.functional as F

from model.enc_net import enc_net
from model.utils import calculate_l1_loss, smooth_in, smooth_out
from data.data_loader import CamVidDataset, loader


torch.manual_seed(555)

data_root_dir = './data'
label_table_name = 'label_color'
color_data = pd.read_csv(Path(data_root_dir, label_table_name))


def main(args):
    model = enc_net(args.n_class)

    if Path(args.resume_model).exists():
        print("load model:", args.resume_model)
        model.load_state_dict(torch.load(args.resume_model))

    # setup optimizer
    optimizer = optim.Adam(
        model.parameters(), lr=args.lr, betas=(args.beta1, 0.999))
    # optimizer = optim.RMSprop(model.parameters(), lr=0.0001)
    # optimizer = optim.SGD(model.parameters(), lr=.01, momentum=.9, weight_decay=.01)

    train_image_names = \
        [line.rstrip() for line in open(args.train_image_pointer_path)]
    test_image_names = \
        [line.rstrip() for line in open(args.test_image_pointer_path)]

    train_dataset = CamVidDataset(train_image_names, args.root_dir)
    test_dataset = CamVidDataset(test_image_names, args.root_dir)
    train_loader = loader(train_dataset, args.batch_size)
    test_loader = loader(test_dataset, 1, shuffle=False)

    train(args, model, optimizer, train_loader)
    test(args, model, test_loader)

def train(args, model, optimizer, data_loader):
    model.train()
    for epoch in range(args.epochs):
        for i, (data, target) in enumerate(data_loader):
            optimizer.zero_grad()
            output, se2, se1 = model(data)
            n_batch = output.shape[0]
            loss = F.nll_loss(F.log_softmax(output, dim=1), target)
            # loss += calculate_l1_loss(output, target)

            # per-image multi-hot vector of which classes are present, used
            # as the target for both semantic-encoding heads
            exist_class = [[1 if c in target[i_batch].numpy() else 0
                            for c in range(args.n_class)]
                           for i_batch in range(n_batch)]
            exist_class = torch.FloatTensor(exist_class)

            loss += F.mse_loss(se2, exist_class)
            loss += F.mse_loss(se1, exist_class)

            # with torch.no_grad():
            #     l_noise = smooth_in(model)
            loss.backward()
            # with torch.no_grad():
            #     smooth_out(model, l_noise)

            optimizer.step()
            print('[{}/{}][{}/{}] Loss: {:.4f}'.format(
                epoch, args.epochs, i,
                len(data_loader), loss.item()))

        # do checkpointing
        torch.save(model.state_dict(),
                   '{}/encnet_ckpt.pth'.format(args.out_dir))


def test(args, model, data_loader):
    model.eval()
    test_loss = 0
    correct = 0
    total_pixels = 0
    with torch.no_grad():
        for i_batch, (data, target) in enumerate(data_loader):
            output, se2, se1 = model(data)
            # sum up batch loss (nll_loss expects log-probabilities)
            test_loss += F.nll_loss(
                F.log_softmax(output, dim=1), target,
                reduction='sum').item()
            # get the index of the max log-probability
            pred = output.argmax(1)
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_pixels += target.numel()

            restoration(pred.numpy(), i_batch, args.n_class)

    # loss and accuracy are averaged over pixels, not images
    test_loss /= total_pixels
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'
          .format(test_loss, correct, total_pixels,
                  100. * correct / total_pixels))


def restoration(result_labels, i_batch, n_class,
                output_img_path='./results/images'):
    Path(output_img_path).mkdir(parents=True, exist_ok=True)
    for i_image, labels in enumerate(result_labels):
        print("---------------labels")
        print(labels.shape)
        h, w = labels.shape
        # convert the output label map back into a color image
        img = np.zeros((h, w, 3))
        for category in range(n_class):
            # np.where returns a tuple of index arrays
            idx = np.where(labels == category)
            if len(idx[0]) > 0:
                color = color_data.iloc[category]
                img[idx[0], idx[1], :] = [color['r'], color['g'], color['b']]
        img = img.astype(np.uint8)
        io.imsave(output_img_path + '/test_result_{}.jpg'
                  .format(str(i_batch).zfill(5)), img)
        # plt.figure()
        io.imshow(img)
        # plt.show()

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--root_dir', default='./data/CamSeq01', help='path to dataset')
    parser.add_argument('--n-class', type=int, default=32, help='number of classes')
    parser.add_argument('--train-image-pointer-path', default='./data/train_image_pointer', help='path to train image pointer')
    parser.add_argument('--test-image-pointer-path', default='./data/test_image_pointer', help='path to test image pointer')
    parser.add_argument('--resume-model', default='./results/_encnet_ckpt.pth', help='path to trained model')
    parser.add_argument('--workers', type=int, default=4, help='number of data loading workers')
    parser.add_argument('--batch-size', type=int, default=16, help='input batch size')
    parser.add_argument('--image-size', type=int, default=256, help='the height / width of the input image to network')
    parser.add_argument('--epochs', type=int, default=200, help='number of epochs to train for')
    parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default=0.0002')
    parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam, default=0.5')
    parser.add_argument('--out-dir', default='./results', help='folder to output images and model checkpoints')
    args = parser.parse_args()
    Path(args.out_dir).mkdir(parents=True, exist_ok=True)

    main(args)
--------------------------------------------------------------------------------
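
Assuming the CamSeq01 frames and the label_color table live under ./data (per
the defaults above), and with src on the import path, a typical run might look
like:

    python src/data/color2class.py
    PYTHONPATH=src python src/train.py --batch-size 4 --epochs 50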