├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── config.py ├── datasets ├── __init__.py └── nyudv2.py ├── draw_net.py ├── models.py └── run.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .nox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | .hypothesis/ 49 | .pytest_cache/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | db.sqlite3 59 | 60 | # Flask stuff: 61 | instance/ 62 | .webassets-cache 63 | 64 | # Scrapy stuff: 65 | .scrapy 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # Jupyter Notebook 74 | .ipynb_checkpoints 75 | 76 | # IPython 77 | profile_default/ 78 | ipython_config.py 79 | 80 | # pyenv 81 | .python-version 82 | 83 | # celery beat schedule file 84 | celerybeat-schedule 85 | 86 | # SageMath parsed files 87 | *.sage.py 88 | 89 | # Environments 90 | .env 91 | .venv 92 | env/ 93 | venv/ 94 | ENV/ 95 | env.bak/ 96 | venv.bak/ 97 | 98 | # Spyder project settings 99 | .spyderproject 100 | .spyproject 101 | 102 | # Rope project settings 103 | .ropeproject 104 | 105 | # mkdocs documentation 106 | /site 107 | 108 | # mypy 109 | .mypy_cache/ 110 | .dmypy.json 111 | dmypy.json 112 | 113 | # Pyre type checker 114 | .pyre/ 115 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Benny 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 3DGNN for RGB-D segmentation 2 | This is the PyTorch implementation of [3D Graph Neural Networks for RGBD Semantic Segmentation](http://openaccess.thecvf.com/content_ICCV_2017/papers/Qi_3D_Graph_Neural_ICCV_2017_paper.pdf). 3 | 4 | ### Data Preparation 5 | 1. Download the NYU Depth V2 labeled dataset from [here](https://cs.nyu.edu/~silberman/datasets/nyu_depth_v2.html), select the scenes you need, and save the file as `./datasets/data/nyu_depth_v2_labeled.mat`. 6 | 2. Convert the depth images to HHA images yourself with [Depth2HHA](https://github.com/charlesCXK/Depth2HHA) and save them in `./datasets/data/hha/`. 7 | 8 | ### Environment 9 | Requires CUDA 8.0 and PyTorch 0.4.1. 10 | 11 | 12 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | '''Data Loader parameter''' 2 | # Multiple threads loading data 3 | workers_tr = 4 4 | workers_va = 4 5 | # Data augmentation 6 | flip_prob = 0.5 7 | crop_size = 0 8 | 9 | '''GNN parameter''' 10 | use_gnn = True 11 | gnn_iterations = 3 12 | gnn_k = 64 13 | mlp_num_layers = 1 14 | 15 | '''Model parameter''' 16 | use_bootstrap_loss = False 17 | bootstrap_rate = 0.25 18 | use_gpu = True 19 | class_weights = [0.0] + [1.0 for i in range(13)] 20 | nclasses = len(class_weights) 21 | 22 | '''Optimizer parameter''' 23 | base_initial_lr = 5e-4 24 | gnn_initial_lr = 1e-3 25 | betas = [0.9, 0.999] 26 | eps = 1e-08 27 | weight_decay = 1e-4 28 | lr_schedule_type = 'exp' 29 | lr_decay = 0.9 30 | lr_patience = 10 31 | 32 | ENCODER_PARAMS = [ 33 | { 34 | 'internal_scale': 4, 35 | 'use_relu': True, 36 | 'asymmetric': False, 37 | 'dilated': False, 38 | 'input_channels': 32, 39 | 'output_channels': 64, 40 | 'downsample': True, 41 | 'dropout_prob': 0.01 42 | }, 43 | { 44 | 'internal_scale': 4, 45 | 'use_relu': True, 46 | 'asymmetric': False, 47 | 'dilated': False, 48 | 'input_channels': 64, 49 | 'output_channels': 64, 50 | 'downsample': False, 51 | 'dropout_prob': 0.01 52 | }, 53 | { 54 | 'internal_scale': 4, 55 | 'use_relu': True, 56 | 'asymmetric': False, 57 | 'dilated': False, 58 | 'input_channels': 64, 59 | 'output_channels': 64, 60 | 'downsample': False, 61 | 'dropout_prob': 0.01 62 | }, 63 | { 64 | 'internal_scale': 4, 65 | 'use_relu': True, 66 | 'asymmetric': False, 67 | 'dilated': False, 68 | 'input_channels': 64, 69 | 'output_channels': 64, 70 | 'downsample': False, 71 | 'dropout_prob': 0.01 72 | }, 73 | { 74 | 'internal_scale': 4, 75 | 'use_relu': True, 76 | 'asymmetric': False, 77 | 'dilated': False, 78 | 'input_channels': 64, 79 | 'output_channels': 64, 80 | 'downsample': False, 81 | 'dropout_prob': 0.01 82 | }, 83 | { 84 | 'internal_scale': 4, 85 | 'use_relu': True, 86 | 'asymmetric': False, 87 | 'dilated': False, 88 | 'input_channels': 64, 89 | 'output_channels': 128, 90 | 'downsample': True, 91 | 'dropout_prob': 0.1 92 | }, 93 | { 94 | 'internal_scale': 4, 95 | 'use_relu': True, 96 | 'asymmetric': False, 97 | 'dilated': False, 98 | 'input_channels': 128, 99 | 'output_channels': 128, 100 | 'downsample': False, 101 | 'dropout_prob':
0.1 102 | }, 103 | { 104 | 'internal_scale': 4, 105 | 'use_relu': True, 106 | 'asymmetric': False, 107 | 'dilated': False, # 2 108 | 'input_channels': 128, 109 | 'output_channels': 128, 110 | 'downsample': False, 111 | 'dropout_prob': 0.1 112 | }, 113 | { 114 | 'internal_scale': 4, 115 | 'use_relu': True, 116 | 'asymmetric': False, 117 | 'dilated': False, 118 | 'input_channels': 128, 119 | 'output_channels': 128, 120 | 'downsample': False, 121 | 'dropout_prob': 0.1 122 | }, 123 | { 124 | 'internal_scale': 4, 125 | 'use_relu': True, 126 | 'asymmetric': False, 127 | 'dilated': False, # 4 128 | 'input_channels': 128, 129 | 'output_channels': 128, 130 | 'downsample': False, 131 | 'dropout_prob': 0.1 132 | }, 133 | { 134 | 'internal_scale': 4, 135 | 'use_relu': True, 136 | 'asymmetric': False, 137 | 'dilated': False, 138 | 'input_channels': 128, 139 | 'output_channels': 128, 140 | 'downsample': False, 141 | 'dropout_prob': 0.1 142 | }, 143 | { 144 | 'internal_scale': 4, 145 | 'use_relu': True, 146 | 'asymmetric': False, 147 | 'dilated': False, # 8 148 | 'input_channels': 128, 149 | 'output_channels': 128, 150 | 'downsample': False, 151 | 'dropout_prob': 0.1 152 | }, 153 | { 154 | 'internal_scale': 4, 155 | 'use_relu': True, 156 | 'asymmetric': False, 157 | 'dilated': False, 158 | 'input_channels': 128, 159 | 'output_channels': 128, 160 | 'downsample': False, 161 | 'dropout_prob': 0.1 162 | }, 163 | { 164 | 'internal_scale': 4, 165 | 'use_relu': True, 166 | 'asymmetric': False, 167 | 'dilated': False, # 16 168 | 'input_channels': 128, 169 | 'output_channels': 128, 170 | 'downsample': False, 171 | 'dropout_prob': 0.1 172 | }, 173 | { 174 | 'internal_scale': 4, 175 | 'use_relu': True, 176 | 'asymmetric': False, 177 | 'dilated': False, 178 | 'input_channels': 128, 179 | 'output_channels': 128, 180 | 'downsample': False, 181 | 'dropout_prob': 0.1 182 | }, 183 | { 184 | 'internal_scale': 4, 185 | 'use_relu': True, 186 | 'asymmetric': False, 187 | 'dilated': False, 188 | 'input_channels': 128, 189 | 'output_channels': 128, 190 | 'downsample': False, 191 | 'dropout_prob': 0.1 192 | }, 193 | { 194 | 'internal_scale': 4, 195 | 'use_relu': True, 196 | 'asymmetric': False, 197 | 'dilated': False, # 2 198 | 'input_channels': 128, 199 | 'output_channels': 128, 200 | 'downsample': False, 201 | 'dropout_prob': 0.1 202 | }, 203 | { 204 | 'internal_scale': 4, 205 | 'use_relu': True, 206 | 'asymmetric': False, 207 | 'dilated': False, 208 | 'input_channels': 128, 209 | 'output_channels': 128, 210 | 'downsample': False, 211 | 'dropout_prob': 0.1 212 | }, 213 | { 214 | 'internal_scale': 4, 215 | 'use_relu': True, 216 | 'asymmetric': False, 217 | 'dilated': False, # 4 218 | 'input_channels': 128, 219 | 'output_channels': 128, 220 | 'downsample': False, 221 | 'dropout_prob': 0.1 222 | }, 223 | { 224 | 'internal_scale': 4, 225 | 'use_relu': True, 226 | 'asymmetric': False, 227 | 'dilated': False, 228 | 'input_channels': 128, 229 | 'output_channels': 128, 230 | 'downsample': False, 231 | 'dropout_prob': 0.1 232 | }, 233 | { 234 | 'internal_scale': 4, 235 | 'use_relu': True, 236 | 'asymmetric': False, 237 | 'dilated': False, # 8 238 | 'input_channels': 128, 239 | 'output_channels': 128, 240 | 'downsample': False, 241 | 'dropout_prob': 0.1 242 | }, 243 | { 244 | 'internal_scale': 4, 245 | 'use_relu': True, 246 | 'asymmetric': False, 247 | 'dilated': False, 248 | 'input_channels': 128, 249 | 'output_channels': 128, 250 | 'downsample': False, 251 | 'dropout_prob': 0.1 252 | }, 253 | { 254 | 'internal_scale': 4, 255 | 
'use_relu': True, 256 | 'asymmetric': False, 257 | 'dilated': False, # 16 258 | 'input_channels': 128, 259 | 'output_channels': 128, 260 | 'downsample': False, 261 | 'dropout_prob': 0.1 262 | } 263 | ] 264 | 265 | DECODER_PARAMS = [ 266 | { 267 | 'input_channels': 128, 268 | 'output_channels': 128, 269 | 'upsample': False, 270 | 'pooling_module': None 271 | }, 272 | { 273 | 'input_channels': 128, 274 | 'output_channels': 64, 275 | 'upsample': True, 276 | 'pooling_module': None 277 | }, 278 | { 279 | 'input_channels': 64, 280 | 'output_channels': 64, 281 | 'upsample': False, 282 | 'pooling_module': None 283 | }, 284 | { 285 | 'input_channels': 64, 286 | 'output_channels': 64, 287 | 'upsample': False, 288 | 'pooling_module': None 289 | }, 290 | { 291 | 'input_channels': 64, 292 | 'output_channels': 32, 293 | 'upsample': True, 294 | 'pooling_module': None 295 | }, 296 | { 297 | 'input_channels': 32, 298 | 'output_channels': 32, 299 | 'upsample': False, 300 | 'pooling_module': None 301 | } 302 | ] -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yanx27/3DGNN_pytorch/b5e0188e56f926cff9b2c9a68bedf42cb3b42d2f/datasets/__init__.py -------------------------------------------------------------------------------- /datasets/nyudv2.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import glob 3 | import numpy as np 4 | import cv2 5 | import h5py 6 | 7 | 8 | class Dataset(Dataset): 9 | def __init__(self, flip_prob=None, crop_type=None, crop_size=0): 10 | 11 | self.flip_prob = flip_prob 12 | self.crop_type = crop_type 13 | self.crop_size = crop_size 14 | 15 | data_path = 'datasets/data/' 16 | data_file = 'nyu_depth_v2_labeled.mat' 17 | 18 | # read mat file 19 | print("Reading .mat file...") 20 | f = h5py.File(data_path + data_file, 'r') 21 | 22 | # h5py objects cannot be pickled (e.g. by DataLoader workers), so copy everything into numpy arrays 23 | 24 | rgb_images_fr = np.transpose(f['images'], [0, 2, 3, 1]).astype(np.float32) 25 | label_images_fr = np.array(f['labels']) 26 | 27 | f.close() 28 | 29 | self.rgb_images = rgb_images_fr 30 | self.label_images = label_images_fr 31 | 32 | 33 | def __len__(self): 34 | return len(self.rgb_images) 35 | 36 | def __getitem__(self, idx): 37 | rgb = self.rgb_images[idx].astype(np.float32) 38 | hha = np.transpose(cv2.cvtColor(cv2.imread("datasets/data/hha/" + str(idx+1) + ".png"), cv2.COLOR_BGR2RGB), [1, 0, 2]) # imread's second argument is a read flag, so do the BGR->RGB conversion with cvtColor 39 | rgb_hha = np.concatenate([rgb, hha], axis=2).astype(np.float32) 40 | label = self.label_images[idx].astype(np.float32) 41 | label[label >= 14] = 0 42 | xy = np.zeros_like(rgb)[:,:,0:2].astype(np.float32) 43 | 44 | # random crop 45 | if self.crop_type is not None and self.crop_size > 0: 46 | max_margin = rgb_hha.shape[0] - self.crop_size 47 | if max_margin == 0: # crop is original size, so nothing to crop 48 | self.crop_type = None 49 | elif self.crop_type == 'Center': 50 | rgb_hha = rgb_hha[max_margin // 2:-max_margin // 2, max_margin // 2:-max_margin // 2, :] # crop the full 6-channel input, not just rgb 51 | label = label[max_margin // 2:-max_margin // 2, max_margin // 2:-max_margin // 2] 52 | xy = xy[max_margin // 2:-max_margin // 2, max_margin // 2:-max_margin // 2, :] 53 | elif self.crop_type == 'Random': 54 | x_ = np.random.randint(0, max_margin) 55 | y_ = np.random.randint(0, max_margin) 56 | rgb_hha = rgb_hha[y_:y_ + self.crop_size, x_:x_ + self.crop_size, :] 57 | label = label[y_:y_ + self.crop_size, x_:x_ + self.crop_size] 58 | xy = xy[y_:y_ + self.crop_size, x_:x_ + self.crop_size, :] 59 | else: 60 | print('Bad crop') # TODO: raise a proper error for an invalid crop_type 61 | exit(0) 62 | 63 | # random flip 64 | if self.flip_prob is not None: 65 | if np.random.random() < self.flip_prob: # flip with probability flip_prob 66 | rgb_hha = np.fliplr(rgb_hha).copy() 67 | label = np.fliplr(label).copy() 68 | xy = np.fliplr(xy).copy() 69 | 70 | return rgb_hha, label, xy 71 | 72 | --------------------------------------------------------------------------------
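A minimal sketch, not part of the repository, of how the Dataset above can be smoke-tested once `nyu_depth_v2_labeled.mat` and the HHA images are in place; the printed shapes are assumptions based on the 640x480 NYU Depth V2 frames:

# hypothetical smoke test for datasets/nyudv2.py
from datasets import nyudv2

ds = nyudv2.Dataset(flip_prob=0.5, crop_type=None, crop_size=0)
print(len(ds))  # number of labeled frames read from the .mat file
rgb_hha, label, xy = ds[0]  # requires datasets/data/hha/1.png to exist
print(rgb_hha.shape, label.shape, xy.shape)  # expected: (640, 480, 6) (640, 480) (640, 480, 2)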
/draw_net.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import os 3 | import matplotlib.pyplot as plt 4 | from torch.utils.data import Dataset 5 | import torch 6 | import numpy as np 7 | import cv2 8 | import h5py 9 | from models import Model 10 | from torch.autograd import Variable 11 | 12 | def show_graph(img): 13 | if len(img.shape) == 3: 14 | b, g, r = cv2.split(img) 15 | img = cv2.merge([r, g, b]) 16 | plt.imshow(img) 17 | plt.axis('off') 18 | plt.show() 19 | 20 | data_path = 'datasets/data/' 21 | data_file = 'nyu_depth_v2_labeled.mat' 22 | f = h5py.File(data_path + data_file, 'r') 23 | rgb_images_fr = np.transpose(f['images'], [0, 2, 3, 1]).astype(np.float32) 24 | depth_images_fr = np.array((f['depths'])) 25 | label_images_fr = np.array(f['labels']) 26 | class_name = f['names'] 27 | 28 | show_graph(rgb_images_fr[0]) 29 | show_graph(depth_images_fr[0]) 30 | show_graph(label_images_fr[0]) 31 | 32 | rgb = rgb_images_fr[0].astype(np.float32) 33 | hha = rgb_images_fr[0].astype(np.float32) # the RGB image is reused as a stand-in for the HHA channels; any 6-channel input is enough to trace the graph 34 | rgb_hha = np.concatenate([rgb, hha], axis=2).astype(np.float32) 35 | rgb_hha = torch.Tensor(rgb_hha) 36 | rgb_hha = rgb_hha.unsqueeze(0) 37 | xy = torch.Tensor(np.zeros_like(rgb)[:,:,0:2].astype(np.float32)) 38 | xy = xy.unsqueeze(0) 39 | xy = xy.permute(0, 3, 1, 2).contiguous() 40 | input = rgb_hha.permute(0, 3, 1, 2).contiguous() 41 | input = input.type(torch.FloatTensor) 42 | 43 | model = Model(14, 1, use_gpu=False) 44 | output = model(Variable(input), gnn_iterations=3, k=64, xy=xy, use_gnn=True) 45 | 46 | from torchviz import make_dot 47 | draw = make_dot(output, params=dict(model.named_parameters())) 48 | draw.view() 49 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.modules.utils import _pair, _quadruple 5 | from torch.autograd import Variable 6 | import config 7 | 8 | ENCODER_PARAMS = config.ENCODER_PARAMS 9 | DECODER_PARAMS = config.DECODER_PARAMS 10 | 11 | class MedianPool2d(nn.Module): 12 | """ Median pool (usable as median filter when stride=1) module.
13 | 14 | Args: 15 | kernel_size: size of pooling kernel, int or 2-tuple 16 | stride: pool stride, int or 2-tuple 17 | padding: pool padding, int or 4-tuple (l, r, t, b) as in pytorch F.pad 18 | same: override padding and enforce same padding, boolean 19 | """ 20 | 21 | def __init__(self, kernel_size=3, stride=1, padding=0, same=False): 22 | super(MedianPool2d, self).__init__() 23 | self.k = _pair(kernel_size) 24 | self.stride = _pair(stride) 25 | self.padding = _quadruple(padding) # convert to l, r, g, b 26 | self.same = same 27 | 28 | def _padding(self, x): 29 | if self.same: 30 | ih, iw = x.size()[2:] 31 | if ih % self.stride[0] == 0: 32 | ph = max(self.k[0] - self.stride[0], 0) 33 | else: 34 | ph = max(self.k[0] - (ih % self.stride[0]), 0) 35 | if iw % self.stride[1] == 0: 36 | pw = max(self.k[1] - self.stride[1], 0) 37 | else: 38 | pw = max(self.k[1] - (iw % self.stride[1]), 0) 39 | pl = pw // 2 40 | pr = pw - pl 41 | pt = ph // 2 42 | pb = ph - pt 43 | padding = (pl, pr, pt, pb) 44 | else: 45 | padding = self.padding 46 | return padding 47 | 48 | def forward(self, x): 49 | # using existing pytorch functions and tensor ops so that we get autograd, 50 | # would likely be more efficient to implement from scratch at C/Cuda level 51 | x = F.pad(x, self._padding(x), mode='reflect') 52 | x = x.unfold(2, self.k[0], self.stride[0]).unfold(3, self.k[1], self.stride[1]) 53 | x = x.contiguous().view(x.size()[:4] + (-1,)).median(dim=-1)[0] 54 | return x 55 | 56 | 57 | class EnetInitialBlock(nn.Module): 58 | def __init__(self): 59 | super().__init__() 60 | 61 | self.conv = nn.Conv2d( 62 | 6, 32, (3, 3), 63 | stride=2, padding=1, bias=True) 64 | # self.pool = nn.MaxPool2d(2, stride=2) 65 | self.batch_norm = nn.BatchNorm2d(32, eps=1e-3) 66 | self.actf = nn.PReLU() 67 | 68 | def forward(self, input): 69 | output = self.conv(input) 70 | output = self.batch_norm(output) 71 | return self.actf(output) 72 | 73 | 74 | class EnetEncoderMainPath(nn.Module): 75 | def __init__(self, internal_scale=None, use_relu=None, asymmetric=None, dilated=None, input_channels=None, 76 | output_channels=None, downsample=None, dropout_prob=None): 77 | super().__init__() 78 | 79 | internal_channels = output_channels // internal_scale 80 | input_stride = downsample and 2 or 1 81 | 82 | self.__dict__.update(locals()) 83 | del self.self 84 | 85 | self.input_conv = nn.Conv2d( 86 | input_channels, internal_channels, input_stride, 87 | stride=input_stride, padding=0, bias=False) 88 | 89 | self.input_batch_norm = nn.BatchNorm2d(internal_channels, eps=1e-03) 90 | 91 | self.middle_conv = nn.Conv2d( 92 | internal_channels, internal_channels, 3, stride=1, bias=True, 93 | dilation=1 if (not dilated or dilated is None) else dilated, 94 | padding=1 if (not dilated or dilated is None) else dilated) 95 | 96 | self.middle_batch_norm = nn.BatchNorm2d(internal_channels, eps=1e-03) 97 | 98 | self.output_conv = nn.Conv2d( 99 | internal_channels, output_channels, 1, 100 | stride=1, padding=0, bias=False) 101 | 102 | self.output_batch_norm = nn.BatchNorm2d(output_channels, eps=1e-03) 103 | 104 | self.dropout = nn.Dropout2d(dropout_prob) 105 | 106 | self.input_actf = nn.PReLU() 107 | self.middle_actf = nn.PReLU() 108 | 109 | def forward(self, input): 110 | output = self.input_conv(input) 111 | 112 | output = self.input_batch_norm(output) 113 | 114 | output = self.input_actf(output) 115 | 116 | output = self.middle_conv(output) 117 | 118 | output = self.middle_batch_norm(output) 119 | 120 | output = self.middle_actf(output) 121 | 122 | output = 
self.output_conv(output) 123 | 124 | output = self.output_batch_norm(output) 125 | 126 | output = self.dropout(output) 127 | 128 | return output 129 | 130 | 131 | class EnetEncoderOtherPath(nn.Module): 132 | def __init__(self, internal_scale=None, use_relu=None, asymmetric=None, dilated=None, input_channels=None, 133 | output_channels=None, downsample=None, **kwargs): 134 | super().__init__() 135 | 136 | self.__dict__.update(locals()) 137 | del self.self 138 | 139 | if downsample: 140 | self.pool = nn.MaxPool2d(2, stride=2, return_indices=True) 141 | 142 | def forward(self, input): 143 | output = input 144 | 145 | if self.downsample: 146 | output, self.indices = self.pool(input) 147 | 148 | if self.output_channels != self.input_channels: 149 | new_size = [1, 1, 1, 1] 150 | new_size[1] = self.output_channels // self.input_channels 151 | output = output.repeat(*new_size) 152 | 153 | return output 154 | 155 | 156 | class EnetEncoderModule(nn.Module): 157 | def __init__(self, **kwargs): 158 | super().__init__() 159 | self.main = EnetEncoderMainPath(**kwargs) 160 | self.other = EnetEncoderOtherPath(**kwargs) 161 | self.actf = nn.PReLU() 162 | 163 | def forward(self, input): 164 | main = self.main(input) 165 | other = self.other(input) 166 | # print("EnetEncoderModule main size:", main.size()) 167 | # print("EnetEncoderModule other size:", other.size()) 168 | return self.actf(main + other) 169 | 170 | 171 | class EnetEncoder(nn.Module): 172 | def __init__(self, params, nclasses): 173 | super().__init__() 174 | self.initial_block = EnetInitialBlock() 175 | 176 | self.layers = [] 177 | for i, params in enumerate(params): 178 | layer_name = 'encoder_{:02d}'.format(i) 179 | layer = EnetEncoderModule(**params) 180 | super().__setattr__(layer_name, layer) 181 | self.layers.append(layer) 182 | 183 | self.output_conv = nn.Conv2d( 184 | 128, nclasses, 1, 185 | stride=1, padding=0, bias=True) 186 | 187 | def forward(self, input): 188 | output = self.initial_block(input) 189 | 190 | for layer in self.layers: 191 | output = layer(output) 192 | 193 | return output 194 | 195 | 196 | class EnetDecoderMainPath(nn.Module): 197 | def __init__(self, input_channels=None, output_channels=None, upsample=None, pooling_module=None): 198 | super().__init__() 199 | 200 | internal_channels = output_channels // 4 201 | input_stride = 2 if upsample is True else 1 202 | 203 | self.__dict__.update(locals()) 204 | del self.self 205 | 206 | self.input_conv = nn.Conv2d( 207 | input_channels, internal_channels, 1, 208 | stride=1, padding=0, bias=False) 209 | 210 | self.input_batch_norm = nn.BatchNorm2d(internal_channels, eps=1e-03) 211 | 212 | if not upsample: 213 | self.middle_conv = nn.Conv2d( 214 | internal_channels, internal_channels, 3, 215 | stride=1, padding=1, bias=True) 216 | else: 217 | self.middle_conv = nn.ConvTranspose2d( 218 | internal_channels, internal_channels, 3, 219 | stride=2, padding=1, output_padding=1, 220 | bias=True) 221 | 222 | self.middle_batch_norm = nn.BatchNorm2d(internal_channels, eps=1e-03) 223 | 224 | self.output_conv = nn.Conv2d( 225 | internal_channels, output_channels, 1, 226 | stride=1, padding=0, bias=False) 227 | 228 | self.output_batch_norm = nn.BatchNorm2d(output_channels, eps=1e-03) 229 | 230 | self.input_actf = nn.PReLU() 231 | self.middle_actf = nn.PReLU() 232 | 233 | def forward(self, input): 234 | output = self.input_conv(input) 235 | 236 | output = self.input_batch_norm(output) 237 | 238 | output = self.input_actf(output) 239 | 240 | output = self.middle_conv(output) 241 | 242 | 
output = self.middle_batch_norm(output) 243 | 244 | output = self.middle_actf(output) 245 | 246 | output = self.output_conv(output) 247 | 248 | output = self.output_batch_norm(output) 249 | 250 | return output 251 | 252 | 253 | class EnetDecoderOtherPath(nn.Module): 254 | def __init__(self, input_channels=None, output_channels=None, upsample=None, pooling_module=None): 255 | super().__init__() 256 | 257 | self.__dict__.update(locals()) 258 | del self.self 259 | 260 | if output_channels != input_channels or upsample: 261 | self.conv = nn.Conv2d( 262 | input_channels, output_channels, 1, 263 | stride=1, padding=0, bias=False) 264 | self.batch_norm = nn.BatchNorm2d(output_channels, eps=1e-03) 265 | if upsample and pooling_module: 266 | self.unpool = nn.MaxUnpool2d(2, stride=2, padding=0) 267 | 268 | def forward(self, input): 269 | output = input 270 | 271 | if self.output_channels != self.input_channels or self.upsample: 272 | output = self.conv(output) 273 | output = self.batch_norm(output) 274 | if self.upsample and self.pooling_module: 275 | output_size = list(output.size()) 276 | output_size[2] *= 2 277 | output_size[3] *= 2 278 | output = self.unpool( 279 | output, self.pooling_module.indices, 280 | output_size=output_size) 281 | 282 | return output 283 | 284 | 285 | class EnetDecoderModule(nn.Module): 286 | def __init__(self, **kwargs): 287 | super().__init__() 288 | 289 | self.main = EnetDecoderMainPath(**kwargs) 290 | self.other = EnetDecoderOtherPath(**kwargs) 291 | self.actf = nn.PReLU() 292 | 293 | def forward(self, input): 294 | main = self.main(input) 295 | other = self.other(input) 296 | # print("EnetDecoderModule main size:", main.size()) 297 | # print("EnetDecoderModule other size:", other.size()) 298 | return self.actf(main + other) 299 | 300 | 301 | class EnetDecoder(nn.Module): 302 | def __init__(self, params, nclasses, encoder): 303 | super().__init__() 304 | 305 | self.encoder = encoder 306 | 307 | self.pooling_modules = [] 308 | 309 | for mod in self.encoder.modules(): 310 | try: 311 | if mod.other.downsample: 312 | self.pooling_modules.append(mod.other) 313 | except AttributeError: 314 | pass 315 | 316 | self.layers = [] 317 | for i, params in enumerate(params): 318 | if params['upsample']: 319 | params['pooling_module'] = self.pooling_modules.pop(-1) 320 | layer = EnetDecoderModule(**params) 321 | self.layers.append(layer) 322 | layer_name = 'decoder{:02d}'.format(i) 323 | super().__setattr__(layer_name, layer) 324 | 325 | self.output_conv = nn.ConvTranspose2d( 326 | 32, nclasses, 2, 327 | stride=2, padding=0, output_padding=0, bias=True) 328 | 329 | def forward(self, input): 330 | output = input 331 | 332 | for layer in self.layers: 333 | output = layer(output) 334 | 335 | output = self.output_conv(output) 336 | 337 | return output 338 | 339 | 340 | class EnetGnn(nn.Module): 341 | def __init__(self, mlp_num_layers,use_gpu): 342 | super().__init__() 343 | 344 | self.median_pool = MedianPool2d(kernel_size=8, stride=8, padding=0, same=False) 345 | 346 | self.g_rnn_layers = nn.ModuleList([nn.Linear(128, 128) for l in range(mlp_num_layers)]) 347 | self.g_rnn_actfs = nn.ModuleList([nn.PReLU() for l in range(mlp_num_layers)]) 348 | self.q_rnn_layer = nn.Linear(256, 128) 349 | self.q_rnn_actf = nn.PReLU() 350 | self.output_conv = nn.Conv2d(256, 128, 3, stride=1, padding=1, bias=True) 351 | self.use_gpu = use_gpu 352 | 353 | # adapted from https://discuss.pytorch.org/t/build-your-own-loss-function-in-pytorch/235/6 354 | # (x - y)^2 = x^2 - 2*x*y + y^2 355 | def 
get_knn_indices(self, batch_mat, k): 356 | r = torch.bmm(batch_mat, batch_mat.permute(0, 2, 1)) 357 | N = r.size()[0] 358 | HW = r.size()[1] 359 | if self.use_gpu: 360 | batch_indices = torch.zeros((N, HW, k)).cuda() 361 | else: 362 | batch_indices = torch.zeros((N, HW, k)) 363 | for idx, val in enumerate(r): 364 | # get the diagonal elements 365 | diag = val.diag().unsqueeze(0) 366 | diag = diag.expand_as(val) 367 | # compute the distance matrix 368 | D = (diag + diag.t() - 2 * val).sqrt() 369 | topk, indices = torch.topk(D, k=k, largest=False) 370 | batch_indices[idx] = indices.data 371 | return batch_indices 372 | 373 | def forward(self, cnn_encoder_output, original_input, gnn_iterations, k, xy): 374 | 375 | # extract for convenience 376 | N = cnn_encoder_output.size()[0] 377 | C = cnn_encoder_output.size()[1] 378 | H = cnn_encoder_output.size()[2] 379 | W = cnn_encoder_output.size()[3] 380 | K = k 381 | 382 | # extract and resize depth image as horizontal disparity channel from HHA encoded image 383 | depth = original_input[:, 3, :, :] # N 8H 8W 384 | depth = depth.view(depth.size()[0], 1, depth.size()[1], depth.size()[2]) 385 | depth_resize = self.median_pool(depth) # N 1 H W 386 | x_coords = xy[:, 0, :, :] 387 | x_coords = x_coords.view(x_coords.size()[0], 1, x_coords.size()[1], x_coords.size()[2]) 388 | y_coords = xy[:, 1, :, :] 389 | y_coords = y_coords.view(y_coords.size()[0], 1, y_coords.size()[1], y_coords.size()[2]) 390 | x_coords = self.median_pool(x_coords) # N 1 H W 391 | y_coords = self.median_pool(y_coords) # N 1 H W 392 | 393 | # 3D projection --> point cloud 394 | proj_3d = torch.cat((x_coords, y_coords, depth_resize), 1) 395 | proj_3d = proj_3d.view(N, 3, H*W).permute(0, 2, 1).contiguous() # N H*W 3 396 | 397 | # get k nearest neighbors 398 | knn = self.get_knn_indices(proj_3d, k=K) # N HW K 399 | knn = knn.view(N, H*W*K).long() # N HWK 400 | 401 | # prepare CNN encoded features for RNN 402 | h = cnn_encoder_output # N C H W 403 | h = h.permute(0, 2, 3, 1).contiguous() # N H W C 404 | h = h.view(N, (H * W), C) # N HW C 405 | 406 | # aggregate and iterate messages in m, keep original CNN features h for later 407 | m = h.clone() # N HW C 408 | 409 | # loop over timestamps to unroll 410 | for i in range(gnn_iterations): 411 | # do this for every samplein batch, not nice, but I don't know how to use index_select batchwise 412 | for n in range(N): 413 | # fetch features from nearest neighbors 414 | neighbor_features = torch.index_select(h[n], 0, Variable(knn[n])).view(H*W, K, C) # H*W K C 415 | # run neighbor features through MLP g and activation function 416 | for idx, g_layer in enumerate(self.g_rnn_layers): 417 | neighbor_features = self.g_rnn_actfs[idx](g_layer(neighbor_features)) # HW K C 418 | # average over activated neighbors 419 | m[n] = torch.mean(neighbor_features, dim=1) # HW C 420 | 421 | # concatenate current state with messages 422 | concat = torch.cat((h, m), 2) # N HW 2C 423 | 424 | # get new features by running MLP q and activation function 425 | h = self.q_rnn_actf(self.q_rnn_layer(concat)) # N HW C 426 | 427 | # format RNN activations back to image, concatenate original CNN embedding, return 428 | h = h.view(N, H, W, C).permute(0, 3, 1, 2).contiguous() # N C H W 429 | output = self.output_conv(torch.cat((cnn_encoder_output, h), 1)) # N 2C H W 430 | return output 431 | 432 | 433 | class Model(nn.Module): 434 | def __init__(self, nclasses, mlp_num_layers,use_gpu): 435 | super().__init__() 436 | 437 | self.encoder = EnetEncoder(ENCODER_PARAMS, nclasses) 438 
| self.gnn = EnetGnn(mlp_num_layers,use_gpu) 439 | self.decoder = EnetDecoder(DECODER_PARAMS, nclasses, self.encoder) 440 | 441 | def forward(self, input, gnn_iterations, k, xy, use_gnn, only_encode=False): 442 | x = self.encoder.forward(input) 443 | 444 | if only_encode: 445 | return x 446 | 447 | if use_gnn: 448 | x = self.gnn.forward(x, input, gnn_iterations, k, xy) 449 | 450 | return self.decoder.forward(x) 451 | 452 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import sys 4 | import time 5 | import numpy as np 6 | import datetime 7 | import logging 8 | import torch 9 | import torch.optim 10 | import torch.nn as nn 11 | from torch.utils.data import DataLoader 12 | from datasets import nyudv2 13 | from models import Model 14 | import config 15 | from tqdm import tqdm 16 | import argparse 17 | 18 | torch.multiprocessing.set_sharing_strategy('file_system') 19 | torch.backends.cudnn.benchmark = True 20 | 21 | def parse_args(): 22 | '''PARAMETERS''' 23 | parser = argparse.ArgumentParser('3dgnn') 24 | parser.add_argument('--num_epochs', default=50,type=int, 25 | help='Number of epoch') 26 | parser.add_argument('--batchsize', type=int, default=4, 27 | help='batch size in training') 28 | parser.add_argument('--pretrain', type=str, default=None, 29 | help='Direction for pretrained weight') 30 | parser.add_argument('--gpu', type=str, default='0', 31 | help='specify gpu device') 32 | 33 | return parser.parse_args() 34 | 35 | def main(args): 36 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 37 | logger = logging.getLogger('3dgnn') 38 | log_path = './experiment/'+ str(datetime.datetime.now().strftime('%Y-%m-%d-%H')).replace(' ', '/') + '/' 39 | print('log path is:',log_path) 40 | if not os.path.exists(log_path): 41 | os.makedirs(log_path) 42 | os.makedirs(log_path + 'save/') 43 | hdlr = logging.FileHandler(log_path + 'log.txt') 44 | formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s') 45 | hdlr.setFormatter(formatter) 46 | logger.addHandler(hdlr) 47 | logger.setLevel(logging.INFO) 48 | logger.info("Loading data...") 49 | print("Loading data...") 50 | 51 | label_to_idx = {'': 0, 'beam': 1, 'board': 2, 'bookcase': 3, 'ceiling': 4, 'chair': 5, 'clutter': 6, 52 | 'column': 7, 53 | 'door': 8, 'floor': 9, 'sofa': 10, 'table': 11, 'wall': 12, 'window': 13} 54 | 55 | idx_to_label = {0: '', 1: 'beam', 2: 'board', 3: 'bookcase', 4: 'ceiling', 5: 'chair', 6: 'clutter', 56 | 7: 'column', 57 | 8: 'door', 9: 'floor', 10: 'sofa', 11: 'table', 12: 'wall', 13: 'window'} 58 | 59 | 60 | 61 | dataset_tr = nyudv2.Dataset(flip_prob=config.flip_prob,crop_type='Random',crop_size=config.crop_size) 62 | dataloader_tr = DataLoader(dataset_tr, batch_size=args.batchsize, shuffle=True, 63 | num_workers=config.workers_tr, drop_last=False, pin_memory=True) 64 | 65 | dataset_va = nyudv2.Dataset(flip_prob=0.0,crop_type='Center',crop_size=config.crop_size) 66 | dataloader_va = DataLoader(dataset_va, batch_size=args.batchsize, shuffle=False, 67 | num_workers=config.workers_va, drop_last=False, pin_memory=True) 68 | cv2.setNumThreads(config.workers_tr) 69 | 70 | 71 | logger.info("Preparing model...") 72 | print("Preparing model...") 73 | model = Model(config.nclasses, config.mlp_num_layers,config.use_gpu) 74 | loss = nn.NLLLoss(reduce=not config.use_bootstrap_loss, weight=torch.FloatTensor(config.class_weights)) 75 | softmax = nn.Softmax(dim=1) 76 | 
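# Note (editorial): NLLLoss applied to log_softmax(output) below is equivalent to cross-entropy
# over the 14 classes; pixels labeled 0 (unlabeled) contribute nothing to the loss because
# config.class_weights assigns class 0 a weight of 0.0.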
log_softmax = nn.LogSoftmax(dim=1) 77 | 78 | if config.use_gpu: 79 | model = model.cuda() 80 | loss = loss.cuda() 81 | softmax = softmax.cuda() 82 | log_softmax = log_softmax.cuda() 83 | 84 | optimizer = torch.optim.Adam([{'params': model.decoder.parameters()}, 85 | {'params': model.gnn.parameters(), 'lr': config.gnn_initial_lr}], 86 | lr=config.base_initial_lr, betas=config.betas, eps=config.eps, weight_decay=config.weight_decay) 87 | 88 | if config.lr_schedule_type == 'exp': 89 | lambda1 = lambda epoch: pow((1 - ((epoch - 1) / args.num_epochs)), config.lr_decay) 90 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1) 91 | elif config.lr_schedule_type == 'plateau': 92 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=config.lr_decay, patience=config.lr_patience) 93 | else: 94 | print('bad scheduler') 95 | exit(1) 96 | 97 | model_parameters = filter(lambda p: p.requires_grad, model.parameters()) 98 | params = sum([np.prod(p.size()) for p in model_parameters]) 99 | logger.info("Number of trainable parameters: %d", params) 100 | 101 | def get_current_learning_rates(): 102 | learning_rates = [] 103 | for param_group in optimizer.param_groups: 104 | learning_rates.append(param_group['lr']) 105 | return learning_rates 106 | 107 | def eval_set(dataloader): 108 | model.eval() 109 | 110 | with torch.no_grad(): 111 | loss_sum = 0.0 112 | confusion_matrix = torch.cuda.FloatTensor(np.zeros(14 ** 2)) 113 | 114 | start_time = time.time() 115 | 116 | for batch_idx, rgbd_label_xy in tqdm(enumerate(dataloader),total = len(dataloader),smoothing=0.9): 117 | x = rgbd_label_xy[0] 118 | xy = rgbd_label_xy[2] 119 | target = rgbd_label_xy[1].long() 120 | x = x.float() 121 | xy = xy.float() 122 | 123 | input = x.permute(0, 3, 1, 2).contiguous() 124 | xy = xy.permute(0, 3, 1, 2).contiguous() 125 | if config.use_gpu: 126 | input = input.cuda() 127 | xy = xy.cuda() 128 | target = target.cuda() 129 | 130 | output = model(input, gnn_iterations=config.gnn_iterations, k=config.gnn_k, xy=xy, use_gnn=config.use_gnn) 131 | 132 | if config.use_bootstrap_loss: 133 | loss_per_pixel = loss.forward(log_softmax(output.float()), target) 134 | topk, indices = torch.topk(loss_per_pixel.view(output.size()[0], -1), 135 | int((config.crop_size ** 2) * config.bootstrap_rate)) 136 | loss_ = torch.mean(topk) 137 | else: 138 | loss_ = loss.forward(log_softmax(output.float()), target) 139 | loss_sum += loss_ 140 | 141 | pred = output.permute(0, 2, 3, 1).contiguous() 142 | pred = pred.view(-1, config.nclasses) 143 | pred = softmax(pred) 144 | pred_max_val, pred_arg_max = pred.max(1) 145 | 146 | pairs = target.view(-1) * 14 + pred_arg_max.view(-1) 147 | for i in range(14 ** 2): 148 | cumu = pairs.eq(i).float().sum() 149 | confusion_matrix[i] += cumu.item() 150 | 151 | sys.stdout.write(" - Eval time: {:.2f}s \n".format(time.time() - start_time)) 152 | loss_sum /= len(dataloader) 153 | 154 | confusion_matrix = confusion_matrix.cpu().numpy().reshape((14, 14)) 155 | class_iou = np.zeros(14) 156 | confusion_matrix[0, :] = np.zeros(14) 157 | confusion_matrix[:, 0] = np.zeros(14) 158 | for i in range(1, 14): 159 | class_iou[i] = confusion_matrix[i, i] / ( 160 | np.sum(confusion_matrix[i, :]) + np.sum(confusion_matrix[:, i]) - confusion_matrix[i, i]) 161 | 162 | return loss_sum.item(), class_iou, confusion_matrix 163 | 164 | '''Training parameter''' 165 | model_to_load = args.pretrain 166 | logger.info("num_epochs: %d", args.num_epochs) 167 | print("Number of epochs: %d"%args.num_epochs) 168 | 
interval_to_show = 100 169 | 170 | train_losses = [] 171 | eval_losses = [] 172 | 173 | if model_to_load: 174 | logger.info("Loading old model...") 175 | print("Loading old model...") 176 | model.load_state_dict(torch.load(model_to_load)) 177 | else: 178 | logger.info("Starting training from scratch...") 179 | print("Starting training from scratch...") 180 | 181 | '''Training''' 182 | for epoch in range(1, args.num_epochs + 1): 183 | batch_loss_avg = 0 184 | if config.lr_schedule_type == 'exp': 185 | scheduler.step(epoch) 186 | for batch_idx, rgbd_label_xy in tqdm(enumerate(dataloader_tr),total = len(dataloader_tr),smoothing=0.9): 187 | x = rgbd_label_xy[0] 188 | target = rgbd_label_xy[1].long() 189 | xy = rgbd_label_xy[2] 190 | x = x.float() 191 | xy = xy.float() 192 | 193 | input = x.permute(0, 3, 1, 2).contiguous() 194 | input = input.type(torch.FloatTensor) 195 | 196 | if config.use_gpu: 197 | input = input.cuda() 198 | xy = xy.cuda() 199 | target = target.cuda() 200 | 201 | xy = xy.permute(0, 3, 1, 2).contiguous() 202 | 203 | optimizer.zero_grad() 204 | model.train() 205 | 206 | output = model(input, gnn_iterations=config.gnn_iterations, k=config.gnn_k, xy=xy, use_gnn=config.use_gnn) 207 | 208 | if config.use_bootstrap_loss: 209 | loss_per_pixel = loss.forward(log_softmax(output.float()), target) 210 | topk, indices = torch.topk(loss_per_pixel.view(output.size()[0], -1), 211 | int((config.crop_size ** 2) * config.bootstrap_rate)) 212 | loss_ = torch.mean(topk) 213 | else: 214 | loss_ = loss.forward(log_softmax(output.float()), target) 215 | 216 | loss_.backward() 217 | optimizer.step() 218 | 219 | batch_loss_avg += loss_.item() 220 | 221 | if batch_idx % interval_to_show == 0 and batch_idx > 0: 222 | batch_loss_avg /= interval_to_show 223 | train_losses.append(batch_loss_avg) 224 | logger.info("E%dB%d Batch loss average: %s", epoch, batch_idx, batch_loss_avg) 225 | print('\rEpoch:{}, Batch:{}, loss average:{}'.format(epoch, batch_idx, batch_loss_avg)) 226 | batch_loss_avg = 0 227 | 228 | batch_idx = len(dataloader_tr) 229 | logger.info("E%dB%d Saving model...", epoch, batch_idx) 230 | 231 | torch.save(model.state_dict(),log_path +'/save/'+'checkpoint_'+str(epoch)+'.pth') 232 | 233 | '''Evaluation''' 234 | eval_loss, class_iou, confusion_matrix = eval_set(dataloader_va) 235 | eval_losses.append(eval_loss) 236 | 237 | if config.lr_schedule_type == 'plateau': 238 | scheduler.step(eval_loss) 239 | print('Learning ...') 240 | logger.info("E%dB%d Def learning rate: %s", epoch, batch_idx, get_current_learning_rates()[0]) 241 | print('Epoch{} Def learning rate: {}'.format(epoch, get_current_learning_rates()[0])) 242 | logger.info("E%dB%d GNN learning rate: %s", epoch, batch_idx, get_current_learning_rates()[1]) 243 | print('Epoch{} GNN learning rate: {}'.format(epoch, get_current_learning_rates()[1])) 244 | logger.info("E%dB%d Eval loss: %s", epoch, batch_idx, eval_loss) 245 | print('Epoch{} Eval loss: {}'.format(epoch, eval_loss)) 246 | logger.info("E%dB%d Class IoU:", epoch, batch_idx) 247 | print('Epoch{} Class IoU:'.format(epoch)) 248 | for cl in range(14): 249 | logger.info("%+10s: %-10s" % (idx_to_label[cl], class_iou[cl])) 250 | print('{}:{}'.format(idx_to_label[cl], class_iou[cl])) 251 | logger.info("Mean IoU: %s", np.mean(class_iou[1:])) 252 | print("Mean IoU: %.2f"%np.mean(class_iou[1:])) 253 | logger.info("E%dB%d Confusion matrix:", epoch, batch_idx) 254 | logger.info(confusion_matrix) 255 | 256 | logger.info("Finished training!") 257 | logger.info("Saving model...") 258 | 
print('Saving final model...') 259 | torch.save(model.state_dict(), log_path + '/save/3dgnn_finish.pth') 260 | eval_loss, class_iou, confusion_matrix = eval_set(dataloader_va) 261 | logger.info("Eval loss: %s", eval_loss) 262 | logger.info("Class IoU:") 263 | for cl in range(14): 264 | logger.info("%+10s: %-10s" % (idx_to_label[cl], class_iou[cl])) 265 | logger.info("Mean IoU: %s", np.mean(class_iou[1:])) 266 | 267 | if __name__ == '__main__': 268 | args = parse_args() 269 | main(args) 270 | --------------------------------------------------------------------------------
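For reference, a minimal inference sketch, not part of the repository: it loads a checkpoint saved by run.py and runs a single forward pass on the CPU. The checkpoint path and the 640x480 input size are assumptions; a real frame would come from datasets/nyudv2.py, permuted to N x C x H x W as in run.py.

# hypothetical inference example for a trained 3DGNN checkpoint
import torch
import config
from models import Model

model = Model(config.nclasses, config.mlp_num_layers, use_gpu=False)
state = torch.load('experiment/<run>/save/3dgnn_finish.pth', map_location='cpu')  # path written by run.py
model.load_state_dict(state)
model.eval()

rgb_hha = torch.randn(1, 6, 640, 480)  # N x 6 x H x W: RGB + HHA channels, as produced by the data loader
xy = torch.zeros(1, 2, 640, 480)       # per-pixel x/y coordinates, zeros as in run.py
with torch.no_grad():
    out = model(rgb_hha, gnn_iterations=config.gnn_iterations, k=config.gnn_k,
                xy=xy, use_gnn=config.use_gnn)
print(out.shape)  # expected: torch.Size([1, 14, 640, 480]), i.e. N x nclasses x H x W
pred = out.argmax(dim=1)  # per-pixel class indices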