├── README.md ├── RedNet_data.py ├── RedNet_inference.py ├── RedNet_model.py ├── RedNet_train.py ├── figure └── overall_structure.png ├── requirements.txt └── utils ├── __init__.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # RedNet 2 | 3 | This repository contains the official implementation of RedNet (Residual Encoder-Decoder Architecture). It turns out that the simple encoder-decoder structure is powerful when combined with residual learning. For further details of the network, please refer to our article [RedNet: Residual Encoder-Decoder Network for indoor RGB-D Semantic Segmentation](http://bit.ly/2MrIT78). 4 | 5 | ![alt text](figure/overall_structure.png "Overall structure of RedNet") 6 | 7 | 8 | 9 | ## Dependencies: 10 | 11 | PyTorch 0.4.0, TensorboardX 1.2, and the other packages listed in `requirements.txt`. 12 | 13 | ## Dataset 14 | 15 | The RedNet model is trained and evaluated on the [SUN RGB-D Benchmark suite](http://rgbd.cs.princeton.edu/paper.pdf). Please download the data from the [official webpage](http://rgbd.cs.princeton.edu), unzip it, and arrange it in a folder tree like this, 16 | 17 | ```bash 18 | SOMEPATH # Some arbitrary path 19 | ├── SUNRGBD # The unzipped folder of SUNRGBD.zip 20 | └── SUNRGBDtoolbox # The unzipped folder of SUNRGBDtoolbox.zip 21 | ``` 22 | 23 | The root path `SOMEPATH` should be passed to the program using the `--data-dir SOMEPATH` argument. 24 | 25 | ## Usage: 26 | 27 | For training, run the following command, 28 | 29 | ``` 30 | python RedNet_train.py --cuda --data-dir /path/to/SOMEPATH 31 | ``` 32 | 33 | If you do not have enough GPU memory, you can pass the `--checkpoint` option to enable gradient checkpointing (`torch.utils.checkpoint`, available in PyTorch >= 0.4), which trades extra computation for lower memory usage. For other configuration options, such as batch size and learning rate, please check the ArgumentParser in [RedNet_train.py](RedNet_train.py). 34 | 35 | For inference, run [RedNet_inference.py](RedNet_inference.py) like this, 36 | 37 | ``` 38 | python RedNet_inference.py --cuda --last-ckpt /path/to/pretrained/model.pth -r /path/to/rgb.png -d /path/to/depth.png -o /path/to/output.png 39 | ``` 40 | 41 | The pre-trained weights are released [here](http://bit.ly/2KDLeu9) for reproducing our results. 42 | 43 | ## Citation 44 | 45 | If you find this work helpful, please consider citing the paper, 46 | 47 | @article{jiang2018rednet, 48 | title={RedNet: Residual Encoder-Decoder Network for indoor RGB-D Semantic Segmentation}, 49 | author={Jiang, Jindong and Zheng, Lunan and Luo, Fei and Zhang, Zhijun}, 50 | journal={arXiv preprint arXiv:1806.01054}, 51 | year={2018} 52 | } 53 | 54 | ## License 55 | 56 | This software is released under a Creative Commons license that allows personal and research use only. 57 | For a commercial license please contact the authors. 
58 | You can view a license summary here: http://creativecommons.org/licenses/by-nc/4.0/ 59 | -------------------------------------------------------------------------------- /RedNet_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.io 3 | import imageio 4 | import h5py 5 | import os 6 | from torch.utils.data import Dataset 7 | import matplotlib 8 | import matplotlib.colors 9 | import skimage.transform 10 | import random 11 | import torchvision 12 | import torch 13 | from RedNet_train import image_h, image_w 14 | 15 | img_dir_train_file = './data/img_dir_train.txt' 16 | depth_dir_train_file = './data/depth_dir_train.txt' 17 | label_dir_train_file = './data/label_train.txt' 18 | img_dir_test_file = './data/img_dir_test.txt' 19 | depth_dir_test_file = './data/depth_dir_test.txt' 20 | label_dir_test_file = './data/label_test.txt' 21 | 22 | 23 | class SUNRGBD(Dataset): 24 | def __init__(self, transform=None, phase_train=True, data_dir=None): 25 | 26 | self.phase_train = phase_train 27 | self.transform = transform 28 | 29 | try: 30 | with open(img_dir_train_file, 'r') as f: 31 | self.img_dir_train = f.read().splitlines() 32 | with open(depth_dir_train_file, 'r') as f: 33 | self.depth_dir_train = f.read().splitlines() 34 | with open(label_dir_train_file, 'r') as f: 35 | self.label_dir_train = f.read().splitlines() 36 | with open(img_dir_test_file, 'r') as f: 37 | self.img_dir_test = f.read().splitlines() 38 | with open(depth_dir_test_file, 'r') as f: 39 | self.depth_dir_test = f.read().splitlines() 40 | with open(label_dir_test_file, 'r') as f: 41 | self.label_dir_test = f.read().splitlines() 42 | except: 43 | if data_dir is None: 44 | data_dir = '/path/to/SUNRGB-D' 45 | SUNRGBDMeta_dir = os.path.join(data_dir, 'SUNRGBDtoolbox/Metadata/SUNRGBDMeta.mat') 46 | allsplit_dir = os.path.join(data_dir, 'SUNRGBDtoolbox/traintestSUNRGBD/allsplit.mat') 47 | SUNRGBD2Dseg_dir = os.path.join(data_dir, 'SUNRGBDtoolbox/Metadata/SUNRGBD2Dseg.mat') 48 | self.img_dir_train = [] 49 | self.depth_dir_train = [] 50 | self.label_dir_train = [] 51 | self.img_dir_test = [] 52 | self.depth_dir_test = [] 53 | self.label_dir_test = [] 54 | self.SUNRGBD2Dseg = h5py.File(SUNRGBD2Dseg_dir, mode='r', libver='latest') 55 | 56 | SUNRGBDMeta = scipy.io.loadmat(SUNRGBDMeta_dir, squeeze_me=True, 57 | struct_as_record=False)['SUNRGBDMeta'] 58 | split = scipy.io.loadmat(allsplit_dir, squeeze_me=True, struct_as_record=False) 59 | split_train = split['alltrain'] 60 | 61 | seglabel = self.SUNRGBD2Dseg['SUNRGBD2Dseg']['seglabel'] 62 | 63 | for i, meta in enumerate(SUNRGBDMeta): 64 | meta_dir = '/'.join(meta.rgbpath.split('/')[:-2]) 65 | real_dir = meta_dir.replace('/n/fs/sun3d/data', data_dir) 66 | depth_bfx_path = os.path.join(real_dir, 'depth_bfx/' + meta.depthname) 67 | rgb_path = os.path.join(real_dir, 'image/' + meta.rgbname) 68 | 69 | label_path = os.path.join(real_dir, 'label/label.npy') 70 | 71 | if not os.path.exists(label_path): 72 | os.makedirs(os.path.join(real_dir, 'label'), exist_ok=True) 73 | label = np.array(self.SUNRGBD2Dseg[seglabel.value[i][0]].value.transpose(1, 0)) 74 | np.save(label_path, label) 75 | 76 | if meta_dir in split_train: 77 | self.img_dir_train = np.append(self.img_dir_train, rgb_path) 78 | self.depth_dir_train = np.append(self.depth_dir_train, depth_bfx_path) 79 | self.label_dir_train = np.append(self.label_dir_train, label_path) 80 | else: 81 | self.img_dir_test = np.append(self.img_dir_test, rgb_path) 82 | 
self.depth_dir_test = np.append(self.depth_dir_test, depth_bfx_path) 83 | self.label_dir_test = np.append(self.label_dir_test, label_path) 84 | 85 | local_file_dir = '/'.join(img_dir_train_file.split('/')[:-1]) 86 | if not os.path.exists(local_file_dir): 87 | os.mkdir(local_file_dir) 88 | with open(img_dir_train_file, 'w') as f: 89 | f.write('\n'.join(self.img_dir_train)) 90 | with open(depth_dir_train_file, 'w') as f: 91 | f.write('\n'.join(self.depth_dir_train)) 92 | with open(label_dir_train_file, 'w') as f: 93 | f.write('\n'.join(self.label_dir_train)) 94 | with open(img_dir_test_file, 'w') as f: 95 | f.write('\n'.join(self.img_dir_test)) 96 | with open(depth_dir_test_file, 'w') as f: 97 | f.write('\n'.join(self.depth_dir_test)) 98 | with open(label_dir_test_file, 'w') as f: 99 | f.write('\n'.join(self.label_dir_test)) 100 | 101 | def __len__(self): 102 | if self.phase_train: 103 | return len(self.img_dir_train) 104 | else: 105 | return len(self.img_dir_test) 106 | 107 | def __getitem__(self, idx): 108 | if self.phase_train: 109 | img_dir = self.img_dir_train 110 | depth_dir = self.depth_dir_train 111 | label_dir = self.label_dir_train 112 | else: 113 | img_dir = self.img_dir_test 114 | depth_dir = self.depth_dir_test 115 | label_dir = self.label_dir_test 116 | 117 | label = np.load(label_dir[idx]) 118 | depth = imageio.imread(depth_dir[idx]) 119 | image = imageio.imread(img_dir[idx]) 120 | 121 | sample = {'image': image, 'depth': depth, 'label': label} 122 | 123 | if self.transform: 124 | sample = self.transform(sample) 125 | 126 | return sample 127 | 128 | 129 | class RandomHSV(object): 130 | """ 131 | Args: 132 | h_range (float tuple): random ratio of the hue channel, 133 | new_h range from h_range[0]*old_h to h_range[1]*old_h. 134 | s_range (float tuple): random ratio of the saturation channel, 135 | new_s range from s_range[0]*old_s to s_range[1]*old_s. 136 | v_range (int tuple): random bias of the value channel, 137 | new_v range from old_v-v_range to old_v+v_range. 
138 | Notice: 139 | h range: 0-1 140 | s range: 0-1 141 | v range: 0-255 142 | """ 143 | 144 | def __init__(self, h_range, s_range, v_range): 145 | assert isinstance(h_range, (list, tuple)) and \ 146 | isinstance(s_range, (list, tuple)) and \ 147 | isinstance(v_range, (list, tuple)) 148 | self.h_range = h_range 149 | self.s_range = s_range 150 | self.v_range = v_range 151 | 152 | def __call__(self, sample): 153 | img = sample['image'] 154 | img_hsv = matplotlib.colors.rgb_to_hsv(img) 155 | img_h, img_s, img_v = img_hsv[:, :, 0], img_hsv[:, :, 1], img_hsv[:, :, 2] 156 | h_random = np.random.uniform(min(self.h_range), max(self.h_range)) 157 | s_random = np.random.uniform(min(self.s_range), max(self.s_range)) 158 | v_random = np.random.uniform(-min(self.v_range), max(self.v_range)) 159 | img_h = np.clip(img_h * h_random, 0, 1) 160 | img_s = np.clip(img_s * s_random, 0, 1) 161 | img_v = np.clip(img_v + v_random, 0, 255) 162 | img_hsv = np.stack([img_h, img_s, img_v], axis=2) 163 | img_new = matplotlib.colors.hsv_to_rgb(img_hsv) 164 | 165 | return {'image': img_new, 'depth': sample['depth'], 'label': sample['label']} 166 | 167 | 168 | class scaleNorm(object): 169 | def __call__(self, sample): 170 | image, depth, label = sample['image'], sample['depth'], sample['label'] 171 | 172 | # Bi-linear 173 | image = skimage.transform.resize(image, (image_h, image_w), order=1, 174 | mode='reflect', preserve_range=True) 175 | # Nearest-neighbor 176 | depth = skimage.transform.resize(depth, (image_h, image_w), order=0, 177 | mode='reflect', preserve_range=True) 178 | label = skimage.transform.resize(label, (image_h, image_w), order=0, 179 | mode='reflect', preserve_range=True) 180 | 181 | return {'image': image, 'depth': depth, 'label': label} 182 | 183 | 184 | class RandomScale(object): 185 | def __init__(self, scale): 186 | self.scale_low = min(scale) 187 | self.scale_high = max(scale) 188 | 189 | def __call__(self, sample): 190 | image, depth, label = sample['image'], sample['depth'], sample['label'] 191 | 192 | target_scale = random.uniform(self.scale_low, self.scale_high) 193 | # (H, W, C) 194 | target_height = int(round(target_scale * image.shape[0])) 195 | target_width = int(round(target_scale * image.shape[1])) 196 | # Bi-linear 197 | image = skimage.transform.resize(image, (target_height, target_width), 198 | order=1, mode='reflect', preserve_range=True) 199 | # Nearest-neighbor 200 | depth = skimage.transform.resize(depth, (target_height, target_width), 201 | order=0, mode='reflect', preserve_range=True) 202 | label = skimage.transform.resize(label, (target_height, target_width), 203 | order=0, mode='reflect', preserve_range=True) 204 | 205 | return {'image': image, 'depth': depth, 'label': label} 206 | 207 | 208 | class RandomCrop(object): 209 | def __init__(self, th, tw): 210 | self.th = th 211 | self.tw = tw 212 | 213 | def __call__(self, sample): 214 | image, depth, label = sample['image'], sample['depth'], sample['label'] 215 | h = image.shape[0] 216 | w = image.shape[1] 217 | i = random.randint(0, h - self.th) 218 | j = random.randint(0, w - self.tw) 219 | 220 | return {'image': image[i:i + image_h, j:j + image_w, :], 221 | 'depth': depth[i:i + image_h, j:j + image_w], 222 | 'label': label[i:i + image_h, j:j + image_w]} 223 | 224 | 225 | class RandomFlip(object): 226 | def __call__(self, sample): 227 | image, depth, label = sample['image'], sample['depth'], sample['label'] 228 | if random.random() > 0.5: 229 | image = np.fliplr(image).copy() 230 | depth = np.fliplr(depth).copy() 231 | label 
= np.fliplr(label).copy() 232 | 233 | return {'image': image, 'depth': depth, 'label': label} 234 | 235 | 236 | # Transforms on torch.*Tensor 237 | class Normalize(object): 238 | def __call__(self, sample): 239 | image, depth = sample['image'], sample['depth'] 240 | image = image / 255 241 | image = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], 242 | std=[0.229, 0.224, 0.225])(image) 243 | depth = torchvision.transforms.Normalize(mean=[19050], 244 | std=[9650])(depth) 245 | sample['image'] = image 246 | sample['depth'] = depth 247 | 248 | return sample 249 | 250 | 251 | class ToTensor(object): 252 | """Convert ndarrays in sample to Tensors.""" 253 | 254 | def __call__(self, sample): 255 | image, depth, label = sample['image'], sample['depth'], sample['label'] 256 | 257 | # Generate different label scales 258 | label2 = skimage.transform.resize(label, (label.shape[0] // 2, label.shape[1] // 2), 259 | order=0, mode='reflect', preserve_range=True) 260 | label3 = skimage.transform.resize(label, (label.shape[0] // 4, label.shape[1] // 4), 261 | order=0, mode='reflect', preserve_range=True) 262 | label4 = skimage.transform.resize(label, (label.shape[0] // 8, label.shape[1] // 8), 263 | order=0, mode='reflect', preserve_range=True) 264 | label5 = skimage.transform.resize(label, (label.shape[0] // 16, label.shape[1] // 16), 265 | order=0, mode='reflect', preserve_range=True) 266 | 267 | # swap color axis because 268 | # numpy image: H x W x C 269 | # torch image: C X H X W 270 | image = image.transpose((2, 0, 1)) 271 | depth = np.expand_dims(depth, 0).astype(np.float) 272 | return {'image': torch.from_numpy(image).float(), 273 | 'depth': torch.from_numpy(depth).float(), 274 | 'label': torch.from_numpy(label).float(), 275 | 'label2': torch.from_numpy(label2).float(), 276 | 'label3': torch.from_numpy(label3).float(), 277 | 'label4': torch.from_numpy(label4).float(), 278 | 'label5': torch.from_numpy(label5).float()} 279 | -------------------------------------------------------------------------------- /RedNet_inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import imageio 4 | import skimage.transform 5 | import torchvision 6 | 7 | import torch.optim 8 | import RedNet_model 9 | from utils import utils 10 | from utils.utils import load_ckpt 11 | 12 | parser = argparse.ArgumentParser(description='RedNet Indoor Sementic Segmentation') 13 | parser.add_argument('-r', '--rgb', default=None, metavar='DIR', 14 | help='path to image') 15 | parser.add_argument('-d', '--depth', default=None, metavar='DIR', 16 | help='path to depth') 17 | parser.add_argument('-o', '--output', default=None, metavar='DIR', 18 | help='path to output') 19 | parser.add_argument('--cuda', action='store_true', default=False, 20 | help='enables CUDA training') 21 | parser.add_argument('--last-ckpt', default='', type=str, metavar='PATH', 22 | help='path to latest checkpoint (default: none)') 23 | 24 | args = parser.parse_args() 25 | device = torch.device("cuda:0" if args.cuda and torch.cuda.is_available() else "cpu") 26 | image_w = 640 27 | image_h = 480 28 | def inference(): 29 | 30 | model = RedNet_model.RedNet(pretrained=False) 31 | load_ckpt(model, None, args.last_ckpt, device) 32 | model.eval() 33 | model.to(device) 34 | 35 | image = imageio.imread(args.rgb) 36 | depth = imageio.imread(args.depth) 37 | 38 | # Bi-linear 39 | image = skimage.transform.resize(image, (image_h, image_w), order=1, 40 | mode='reflect', preserve_range=True) 
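# The RGB image above is resized with bilinear interpolation (order=1); the depth map
# below uses nearest-neighbor (order=0) so that raw depth readings are not blended
# across object boundaries. RedNet_data.py follows the same convention for depth and
# label maps.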
41 | # Nearest-neighbor 42 | depth = skimage.transform.resize(depth, (image_h, image_w), order=0, 43 | mode='reflect', preserve_range=True) 44 | 45 | image = image / 255 46 | image = torch.from_numpy(image).float() 47 | depth = torch.from_numpy(depth).float() 48 | image = image.permute(2, 0, 1) 49 | depth.unsqueeze_(0) 50 | 51 | image = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], 52 | std=[0.229, 0.224, 0.225])(image) 53 | depth = torchvision.transforms.Normalize(mean=[19050], 54 | std=[9650])(depth) 55 | 56 | image = image.to(device).unsqueeze_(0) 57 | depth = depth.to(device).unsqueeze_(0) 58 | 59 | pred = model(image, depth) 60 | 61 | output = utils.color_label(torch.max(pred, 1)[1] + 1)[0] 62 | 63 | imageio.imsave(args.output, output.cpu().numpy().transpose((1, 2, 0))) 64 | 65 | if __name__ == '__main__': 66 | inference() 67 | -------------------------------------------------------------------------------- /RedNet_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import math 4 | import torch.utils.model_zoo as model_zoo 5 | from utils import utils 6 | from torch.utils.checkpoint import checkpoint 7 | 8 | 9 | class RedNet(nn.Module): 10 | def __init__(self, num_classes=37, pretrained=False): 11 | 12 | super(RedNet, self).__init__() 13 | block = Bottleneck 14 | transblock = TransBasicBlock 15 | layers = [3, 4, 6, 3] 16 | # original resnet 17 | self.inplanes = 64 18 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 19 | bias=False) 20 | self.bn1 = nn.BatchNorm2d(64) 21 | self.relu = nn.ReLU(inplace=True) 22 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 23 | self.layer1 = self._make_layer(block, 64, layers[0]) 24 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 25 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 26 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 27 | 28 | # resnet for depth channel 29 | self.inplanes = 64 30 | self.conv1_d = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, 31 | bias=False) 32 | self.bn1_d = nn.BatchNorm2d(64) 33 | self.layer1_d = self._make_layer(block, 64, layers[0]) 34 | self.layer2_d = self._make_layer(block, 128, layers[1], stride=2) 35 | self.layer3_d = self._make_layer(block, 256, layers[2], stride=2) 36 | self.layer4_d = self._make_layer(block, 512, layers[3], stride=2) 37 | 38 | self.inplanes = 512 39 | self.deconv1 = self._make_transpose(transblock, 256, 6, stride=2) 40 | self.deconv2 = self._make_transpose(transblock, 128, 4, stride=2) 41 | self.deconv3 = self._make_transpose(transblock, 64, 3, stride=2) 42 | self.deconv4 = self._make_transpose(transblock, 64, 3, stride=2) 43 | 44 | self.agant0 = self._make_agant_layer(64, 64) 45 | self.agant1 = self._make_agant_layer(64 * 4, 64) 46 | self.agant2 = self._make_agant_layer(128 * 4, 128) 47 | self.agant3 = self._make_agant_layer(256 * 4, 256) 48 | self.agant4 = self._make_agant_layer(512 * 4, 512) 49 | 50 | # final block 51 | self.inplanes = 64 52 | self.final_conv = self._make_transpose(transblock, 64, 3) 53 | 54 | self.final_deconv = nn.ConvTranspose2d(self.inplanes, num_classes, kernel_size=2, 55 | stride=2, padding=0, bias=True) 56 | 57 | self.out5_conv = nn.Conv2d(256, num_classes, kernel_size=1, stride=1, bias=True) 58 | self.out4_conv = nn.Conv2d(128, num_classes, kernel_size=1, stride=1, bias=True) 59 | self.out3_conv = nn.Conv2d(64, num_classes, kernel_size=1, stride=1, bias=True) 60 | 
self.out2_conv = nn.Conv2d(64, num_classes, kernel_size=1, stride=1, bias=True) 61 | 62 | for m in self.modules(): 63 | if isinstance(m, nn.Conv2d): 64 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 65 | m.weight.data.normal_(0, math.sqrt(2. / n)) 66 | elif isinstance(m, nn.BatchNorm2d): 67 | m.weight.data.fill_(1) 68 | m.bias.data.zero_() 69 | 70 | if pretrained: 71 | self._load_resnet_pretrained() 72 | 73 | def _make_layer(self, block, planes, blocks, stride=1): 74 | downsample = None 75 | if stride != 1 or self.inplanes != planes * block.expansion: 76 | downsample = nn.Sequential( 77 | nn.Conv2d(self.inplanes, planes * block.expansion, 78 | kernel_size=1, stride=stride, bias=False), 79 | nn.BatchNorm2d(planes * block.expansion), 80 | ) 81 | 82 | layers = [] 83 | 84 | layers.append(block(self.inplanes, planes, stride, downsample)) 85 | self.inplanes = planes * block.expansion 86 | for i in range(1, blocks): 87 | layers.append(block(self.inplanes, planes)) 88 | 89 | return nn.Sequential(*layers) 90 | 91 | def _make_transpose(self, block, planes, blocks, stride=1): 92 | 93 | upsample = None 94 | if stride != 1: 95 | upsample = nn.Sequential( 96 | nn.ConvTranspose2d(self.inplanes, planes, 97 | kernel_size=2, stride=stride, 98 | padding=0, bias=False), 99 | nn.BatchNorm2d(planes), 100 | ) 101 | elif self.inplanes != planes: 102 | upsample = nn.Sequential( 103 | nn.Conv2d(self.inplanes, planes, 104 | kernel_size=1, stride=stride, bias=False), 105 | nn.BatchNorm2d(planes), 106 | ) 107 | 108 | layers = [] 109 | 110 | for i in range(1, blocks): 111 | layers.append(block(self.inplanes, self.inplanes)) 112 | 113 | layers.append(block(self.inplanes, planes, stride, upsample)) 114 | self.inplanes = planes 115 | 116 | return nn.Sequential(*layers) 117 | 118 | def _make_agant_layer(self, inplanes, planes): 119 | 120 | layers = nn.Sequential( 121 | nn.Conv2d(inplanes, planes, kernel_size=1, 122 | stride=1, padding=0, bias=False), 123 | nn.BatchNorm2d(planes), 124 | nn.ReLU(inplace=True) 125 | ) 126 | return layers 127 | 128 | def _load_resnet_pretrained(self): 129 | pretrain_dict = model_zoo.load_url(utils.model_urls['resnet50']) 130 | model_dict = {} 131 | state_dict = self.state_dict() 132 | for k, v in pretrain_dict.items(): 133 | if k in state_dict: 134 | if k.startswith('conv1'): # the first conv_op 135 | model_dict[k] = v 136 | model_dict[k.replace('conv1', 'conv1_d')] = torch.mean(v, 1).data. 
\ 137 | view_as(state_dict[k.replace('conv1', 'conv1_d')]) 138 | 139 | elif k.startswith('bn1'): 140 | model_dict[k] = v 141 | model_dict[k.replace('bn1', 'bn1_d')] = v 142 | elif k.startswith('layer'): 143 | model_dict[k] = v 144 | model_dict[k[:6] + '_d' + k[6:]] = v 145 | state_dict.update(model_dict) 146 | self.load_state_dict(state_dict) 147 | 148 | def forward_downsample(self, rgb, depth): 149 | 150 | x = self.conv1(rgb) 151 | x = self.bn1(x) 152 | x = self.relu(x) 153 | depth = self.conv1_d(depth) 154 | depth = self.bn1_d(depth) 155 | depth = self.relu(depth) 156 | 157 | fuse0 = x + depth 158 | 159 | x = self.maxpool(fuse0) 160 | depth = self.maxpool(depth) 161 | 162 | # block 1 163 | x = self.layer1(x) 164 | depth = self.layer1_d(depth) 165 | fuse1 = x + depth 166 | # block 2 167 | x = self.layer2(fuse1) 168 | depth = self.layer2_d(depth) 169 | fuse2 = x + depth 170 | # block 3 171 | x = self.layer3(fuse2) 172 | depth = self.layer3_d(depth) 173 | fuse3 = x + depth 174 | # block 4 175 | x = self.layer4(fuse3) 176 | depth = self.layer4_d(depth) 177 | fuse4 = x + depth 178 | 179 | return fuse0, fuse1, fuse2, fuse3, fuse4 180 | 181 | def forward_upsample(self, fuse0, fuse1, fuse2, fuse3, fuse4): 182 | 183 | agant4 = self.agant4(fuse4) 184 | # upsample 1 185 | x = self.deconv1(agant4) 186 | if self.training: 187 | out5 = self.out5_conv(x) 188 | x = x + self.agant3(fuse3) 189 | # upsample 2 190 | x = self.deconv2(x) 191 | if self.training: 192 | out4 = self.out4_conv(x) 193 | x = x + self.agant2(fuse2) 194 | # upsample 3 195 | x = self.deconv3(x) 196 | if self.training: 197 | out3 = self.out3_conv(x) 198 | x = x + self.agant1(fuse1) 199 | # upsample 4 200 | x = self.deconv4(x) 201 | if self.training: 202 | out2 = self.out2_conv(x) 203 | x = x + self.agant0(fuse0) 204 | # final 205 | x = self.final_conv(x) 206 | out = self.final_deconv(x) 207 | 208 | if self.training: 209 | return out, out2, out3, out4, out5 210 | 211 | return out 212 | 213 | def forward(self, rgb, depth, phase_checkpoint=False): 214 | 215 | if phase_checkpoint: 216 | depth.requires_grad_() 217 | fuses = checkpoint(self.forward_downsample, rgb, depth) 218 | out = checkpoint(self.forward_upsample, *fuses) 219 | else: 220 | fuses = self.forward_downsample(rgb, depth) 221 | out = self.forward_upsample(*fuses) 222 | 223 | return out 224 | 225 | 226 | def conv3x3(in_planes, out_planes, stride=1): 227 | "3x3 convolution with padding" 228 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 229 | padding=1, bias=False) 230 | 231 | 232 | class Bottleneck(nn.Module): 233 | expansion = 4 234 | 235 | def __init__(self, inplanes, planes, stride=1, downsample=None): 236 | super(Bottleneck, self).__init__() 237 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 238 | self.bn1 = nn.BatchNorm2d(planes) 239 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 240 | padding=1, bias=False) 241 | self.bn2 = nn.BatchNorm2d(planes) 242 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 243 | self.bn3 = nn.BatchNorm2d(planes * 4) 244 | self.relu = nn.ReLU(inplace=True) 245 | self.downsample = downsample 246 | self.stride = stride 247 | 248 | def forward(self, x): 249 | residual = x 250 | 251 | out = self.conv1(x) 252 | out = self.bn1(out) 253 | out = self.relu(out) 254 | 255 | out = self.conv2(out) 256 | out = self.bn2(out) 257 | out = self.relu(out) 258 | 259 | out = self.conv3(out) 260 | out = self.bn3(out) 261 | 262 | if self.downsample is not None: 263 | residual = 
self.downsample(x) 264 | 265 | out += residual 266 | out = self.relu(out) 267 | 268 | return out 269 | 270 | class TransBasicBlock(nn.Module): 271 | expansion = 1 272 | 273 | def __init__(self, inplanes, planes, stride=1, upsample=None, **kwargs): 274 | super(TransBasicBlock, self).__init__() 275 | self.conv1 = conv3x3(inplanes, inplanes) 276 | self.bn1 = nn.BatchNorm2d(inplanes) 277 | self.relu = nn.ReLU(inplace=True) 278 | if upsample is not None and stride != 1: 279 | self.conv2 = nn.ConvTranspose2d(inplanes, planes, 280 | kernel_size=3, stride=stride, padding=1, 281 | output_padding=1, bias=False) 282 | else: 283 | self.conv2 = conv3x3(inplanes, planes, stride) 284 | self.bn2 = nn.BatchNorm2d(planes) 285 | self.upsample = upsample 286 | self.stride = stride 287 | 288 | def forward(self, x): 289 | residual = x 290 | 291 | out = self.conv1(x) 292 | out = self.bn1(out) 293 | out = self.relu(out) 294 | 295 | out = self.conv2(out) 296 | out = self.bn2(out) 297 | 298 | if self.upsample is not None: 299 | residual = self.upsample(x) 300 | 301 | out += residual 302 | out = self.relu(out) 303 | 304 | return out 305 | -------------------------------------------------------------------------------- /RedNet_train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | import torch 5 | 6 | from torch.utils.data import DataLoader 7 | import torch.optim 8 | import torchvision.transforms as transforms 9 | from torchvision.utils import make_grid 10 | from torch import nn 11 | 12 | from tensorboardX import SummaryWriter 13 | 14 | import RedNet_model 15 | import RedNet_data 16 | from utils import utils 17 | from utils.utils import save_ckpt 18 | from utils.utils import load_ckpt 19 | from utils.utils import print_log 20 | from torch.optim.lr_scheduler import LambdaLR 21 | 22 | parser = argparse.ArgumentParser(description='RedNet Indoor Semantic Segmentation') 23 | parser.add_argument('--data-dir', default=None, metavar='DIR', 24 | help='path to SUNRGB-D') 25 | parser.add_argument('--cuda', action='store_true', default=False, 26 | help='enables CUDA training') 27 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 28 | help='number of data loading workers (default: 4)') 29 | parser.add_argument('--epochs', default=1500, type=int, metavar='N', 30 | help='number of total epochs to run (default: 1500)') 31 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 32 | help='manual epoch number (useful on restarts)') 33 | parser.add_argument('-b', '--batch-size', default=5, type=int, 34 | metavar='N', help='mini-batch size (default: 5)') 35 | parser.add_argument('--lr', '--learning-rate', default=2e-3, type=float, 36 | metavar='LR', help='initial learning rate') 37 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, 38 | metavar='W', help='weight decay (default: 1e-4)') 39 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 40 | help='momentum') 41 | parser.add_argument('--print-freq', '-p', default=200, type=int, 42 | metavar='N', help='print batch frequency (default: 200)') 43 | parser.add_argument('--save-epoch-freq', '-s', default=5, type=int, 44 | metavar='N', help='save epoch frequency (default: 5)') 45 | parser.add_argument('--last-ckpt', default='', type=str, metavar='PATH', 46 | help='path to latest checkpoint (default: none)') 47 | parser.add_argument('--lr-decay-rate', default=0.8, type=float, 48 | help='decay rate of learning rate 
(default: 0.8)') 49 | parser.add_argument('--lr-epoch-per-decay', default=100, type=int, 50 | help='epoch of per decay of learning rate (default: 150)') 51 | parser.add_argument('--ckpt-dir', default='./model/', metavar='DIR', 52 | help='path to save checkpoints') 53 | parser.add_argument('--summary-dir', default='./summary', metavar='DIR', 54 | help='path to save summary') 55 | parser.add_argument('--checkpoint', action='store_true', default=False, 56 | help='Using Pytorch checkpoint or not') 57 | 58 | args = parser.parse_args() 59 | device = torch.device("cuda:0" if args.cuda and torch.cuda.is_available() else "cpu") 60 | image_w = 640 61 | image_h = 480 62 | def train(): 63 | train_data = RedNet_data.SUNRGBD(transform=transforms.Compose([RedNet_data.scaleNorm(), 64 | RedNet_data.RandomScale((1.0, 1.4)), 65 | RedNet_data.RandomHSV((0.9, 1.1), 66 | (0.9, 1.1), 67 | (25, 25)), 68 | RedNet_data.RandomCrop(image_h, image_w), 69 | RedNet_data.RandomFlip(), 70 | RedNet_data.ToTensor(), 71 | RedNet_data.Normalize()]), 72 | phase_train=True, 73 | data_dir=args.data_dir) 74 | train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True, 75 | num_workers=args.workers, pin_memory=False) 76 | 77 | num_train = len(train_data) 78 | 79 | if args.last_ckpt: 80 | model = RedNet_model.RedNet(pretrained=False) 81 | else: 82 | model = RedNet_model.RedNet(pretrained=True) 83 | if torch.cuda.device_count() > 1: 84 | print("Let's use", torch.cuda.device_count(), "GPUs!") 85 | model = nn.DataParallel(model) 86 | CEL_weighted = utils.CrossEntropyLoss2d() 87 | model.train() 88 | model.to(device) 89 | CEL_weighted.to(device) 90 | optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, 91 | momentum=args.momentum, weight_decay=args.weight_decay) 92 | 93 | global_step = 0 94 | 95 | if args.last_ckpt: 96 | global_step, args.start_epoch = load_ckpt(model, optimizer, args.last_ckpt, device) 97 | 98 | lr_decay_lambda = lambda epoch: args.lr_decay_rate ** (epoch // args.lr_epoch_per_decay) 99 | scheduler = LambdaLR(optimizer, lr_lambda=lr_decay_lambda) 100 | 101 | writer = SummaryWriter(args.summary_dir) 102 | 103 | for epoch in range(int(args.start_epoch), args.epochs): 104 | 105 | scheduler.step(epoch) 106 | local_count = 0 107 | last_count = 0 108 | end_time = time.time() 109 | if epoch % args.save_epoch_freq == 0 and epoch != args.start_epoch: 110 | save_ckpt(args.ckpt_dir, model, optimizer, global_step, epoch, 111 | local_count, num_train) 112 | 113 | for batch_idx, sample in enumerate(train_loader): 114 | 115 | image = sample['image'].to(device) 116 | depth = sample['depth'].to(device) 117 | target_scales = [sample[s].to(device) for s in ['label', 'label2', 'label3', 'label4', 'label5']] 118 | optimizer.zero_grad() 119 | pred_scales = model(image, depth, args.checkpoint) 120 | loss = CEL_weighted(pred_scales, target_scales) 121 | loss.backward() 122 | optimizer.step() 123 | local_count += image.data.shape[0] 124 | global_step += 1 125 | if global_step % args.print_freq == 0 or global_step == 1: 126 | 127 | time_inter = time.time() - end_time 128 | count_inter = local_count - last_count 129 | print_log(global_step, epoch, local_count, count_inter, 130 | num_train, loss, time_inter) 131 | end_time = time.time() 132 | 133 | for name, param in model.named_parameters(): 134 | writer.add_histogram(name, param.clone().cpu().data.numpy(), global_step, bins='doane') 135 | grid_image = make_grid(image[:3].clone().cpu().data, 3, normalize=True) 136 | writer.add_image('image', grid_image, global_step) 137 
| grid_image = make_grid(depth[:3].clone().cpu().data, 3, normalize=True) 138 | writer.add_image('depth', grid_image, global_step) 139 | grid_image = make_grid(utils.color_label(torch.max(pred_scales[0][:3], 1)[1] + 1), 3, normalize=False, 140 | range=(0, 255)) 141 | writer.add_image('Predicted label', grid_image, global_step) 142 | grid_image = make_grid(utils.color_label(target_scales[0][:3]), 3, normalize=False, range=(0, 255)) 143 | writer.add_image('Groundtruth label', grid_image, global_step) 144 | writer.add_scalar('CrossEntropyLoss', loss.data, global_step=global_step) 145 | writer.add_scalar('Learning rate', scheduler.get_lr()[0], global_step=global_step) 146 | last_count = local_count 147 | 148 | save_ckpt(args.ckpt_dir, model, optimizer, global_step, args.epochs, 149 | 0, num_train) 150 | 151 | print("Training completed") 152 | 153 | if __name__ == '__main__': 154 | if not os.path.exists(args.ckpt_dir): 155 | os.mkdir(args.ckpt_dir) 156 | if not os.path.exists(args.summary_dir): 157 | os.mkdir(args.summary_dir) 158 | 159 | train() 160 | -------------------------------------------------------------------------------- /figure/overall_structure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JindongJiang/RedNet/1835eb525195f751ca586f0eca0a3c5659373dcc/figure/overall_structure.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=0.4.0 2 | numpy 3 | imageio 4 | scipy 5 | tensorboardX 6 | matplotlib 7 | scikit-image 8 | h5py -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JindongJiang/RedNet/1835eb525195f751ca586f0eca0a3c5659373dcc/utils/__init__.py -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch import nn 3 | import torch 4 | import os 5 | 6 | med_frq = [0.382900, 0.452448, 0.637584, 0.377464, 0.585595, 7 | 0.479574, 0.781544, 0.982534, 1.017466, 0.624581, 8 | 2.589096, 0.980794, 0.920340, 0.667984, 1.172291, 9 | 0.862240, 0.921714, 2.154782, 1.187832, 1.178115, 10 | 1.848545, 1.428922, 2.849658, 0.771605, 1.656668, 11 | 4.483506, 2.209922, 1.120280, 2.790182, 0.706519, 12 | 3.994768, 2.220004, 0.972934, 1.481525, 5.342475, 13 | 0.750738, 4.040773] 14 | 15 | model_urls = { 16 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 17 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 18 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 19 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 20 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 21 | } 22 | 23 | label_colours = [(0, 0, 0), 24 | # 0=background 25 | (148, 65, 137), (255, 116, 69), (86, 156, 137), 26 | (202, 179, 158), (155, 99, 235), (161, 107, 108), 27 | (133, 160, 103), (76, 152, 126), (84, 62, 35), 28 | (44, 80, 130), (31, 184, 157), (101, 144, 77), 29 | (23, 197, 62), (141, 168, 145), (142, 151, 136), 30 | (115, 201, 77), (100, 216, 255), (57, 156, 36), 31 | (88, 108, 129), (105, 129, 112), (42, 137, 126), 32 | (155, 108, 249), (166, 148, 143), (81, 91, 87), 
33 | (100, 124, 51), (73, 131, 121), (157, 210, 220), 34 | (134, 181, 60), (221, 223, 147), (123, 108, 131), 35 | (161, 66, 179), (163, 221, 160), (31, 146, 98), 36 | (99, 121, 30), (49, 89, 240), (116, 108, 9), 37 | (161, 176, 169), (80, 29, 135), (177, 105, 197), 38 | (139, 110, 246)] 39 | 40 | 41 | class CrossEntropyLoss2d(nn.Module): 42 | def __init__(self, weight=med_frq): 43 | super(CrossEntropyLoss2d, self).__init__() 44 | self.ce_loss = nn.CrossEntropyLoss(torch.from_numpy(np.array(weight)).float(), 45 | size_average=False, reduce=False) 46 | 47 | def forward(self, inputs_scales, targets_scales): 48 | losses = [] 49 | for inputs, targets in zip(inputs_scales, targets_scales): 50 | mask = targets > 0 51 | targets_m = targets.clone() 52 | targets_m[mask] -= 1 53 | loss_all = self.ce_loss(inputs, targets_m.long()) 54 | losses.append(torch.sum(torch.masked_select(loss_all, mask)) / torch.sum(mask.float())) 55 | total_loss = sum(losses) 56 | return total_loss 57 | 58 | 59 | def color_label(label): 60 | label = label.clone().cpu().data.numpy() 61 | colored_label = np.vectorize(lambda x: label_colours[int(x)]) 62 | 63 | colored = np.asarray(colored_label(label)).astype(np.float32) 64 | colored = colored.squeeze() 65 | 66 | try: 67 | return torch.from_numpy(colored.transpose([1, 0, 2, 3])) 68 | except ValueError: 69 | return torch.from_numpy(colored[np.newaxis, ...]) 70 | 71 | 72 | def print_log(global_step, epoch, local_count, count_inter, dataset_size, loss, time_inter): 73 | print('Step: {:>5} Train Epoch: {:>3} [{:>4}/{:>4} ({:3.1f}%)] ' 74 | 'Loss: {:.6f} [{:.2f}s every {:>4} data]'.format( 75 | global_step, epoch, local_count, dataset_size, 76 | 100. * local_count / dataset_size, loss.data, time_inter, count_inter)) 77 | 78 | 79 | def save_ckpt(ckpt_dir, model, optimizer, global_step, epoch, local_count, num_train): 80 | # usually this happens only on the start of a epoch 81 | epoch_float = epoch + (local_count / num_train) 82 | state = { 83 | 'global_step': global_step, 84 | 'epoch': epoch_float, 85 | 'state_dict': model.state_dict(), 86 | 'optimizer': optimizer.state_dict(), 87 | } 88 | ckpt_model_filename = "ckpt_epoch_{:0.2f}.pth".format(epoch_float) 89 | path = os.path.join(ckpt_dir, ckpt_model_filename) 90 | torch.save(state, path) 91 | print('{:>2} has been successfully saved'.format(path)) 92 | 93 | 94 | def load_ckpt(model, optimizer, model_file, device): 95 | if os.path.isfile(model_file): 96 | print("=> loading checkpoint '{}'".format(model_file)) 97 | if device.type == 'cuda': 98 | checkpoint = torch.load(model_file) 99 | else: 100 | checkpoint = torch.load(model_file, map_location=lambda storage, loc: storage) 101 | model.load_state_dict(checkpoint['state_dict']) 102 | if optimizer: 103 | optimizer.load_state_dict(checkpoint['optimizer']) 104 | print("=> loaded checkpoint '{}' (epoch {})" 105 | .format(model_file, checkpoint['epoch'])) 106 | step = checkpoint['global_step'] 107 | epoch = checkpoint['epoch'] 108 | return step, epoch 109 | else: 110 | print("=> no checkpoint found at '{}'".format(model_file)) 111 | os._exit(0) 112 | --------------------------------------------------------------------------------
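As a quick reference for the helpers defined in `utils/utils.py` above, the following is a minimal, hypothetical usage sketch; it is not a file of the repository. It assumes it is run from the repository root so that `utils.utils` is importable, and it feeds random tensors in place of real SUN RGB-D batches, so the batch size, spatial size, and scale factors are illustrative only. The convention it demonstrates comes from the code above: ground-truth label maps are 1-based with 0 marking unlabeled pixels, `CrossEntropyLoss2d` shifts them to 0-based class indices and masks the unlabeled pixels out, and `color_label` expects 1-based ids (hence the `+ 1` on the argmax of the logits, exactly as in `RedNet_train.py` and `RedNet_inference.py`).

```python
# Hypothetical usage sketch with random data; not part of the repository.
import torch

from utils.utils import CrossEntropyLoss2d, color_label

num_classes = 37          # SUN RGB-D classes; label value 0 is reserved for "unlabeled"
batch, h, w = 2, 96, 128  # tiny spatial size, purely for illustration

# One prediction/target pair per scale: full resolution plus 1/2, 1/4, 1/8 and 1/16,
# matching the five outputs RedNet returns in training mode.
scales = (1, 2, 4, 8, 16)
pred_scales = [torch.randn(batch, num_classes, h // s, w // s) for s in scales]
target_scales = [torch.randint(0, num_classes + 1, (batch, h // s, w // s)).float()
                 for s in scales]

criterion = CrossEntropyLoss2d()              # weighted by the med_frq class frequencies
loss = criterion(pred_scales, target_scales)  # unlabeled (0) pixels are masked out
print('multi-scale loss:', loss.item())

# color_label expects 1-based class ids, so add 1 to the 0-based argmax of the logits.
colored = color_label(torch.max(pred_scales[0], 1)[1] + 1)
print('colored prediction shape:', tuple(colored.shape))  # (batch, 3, h, w)
```

The five prediction/target pairs mirror the multi-scale supervision used during training, where RedNet returns side outputs at 1/2, 1/4, 1/8, and 1/16 of the input resolution in addition to the full-resolution prediction.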