├── .gitignore ├── LICENSE ├── README.md ├── dataloader └── KITTI2015_loader.py ├── inference.py ├── models ├── PSMnet.py ├── __init__.py ├── costnet.py ├── smoothloss.py └── stackedhourglass.py ├── pic ├── 01.png ├── 02.png ├── 03.png ├── 04.png ├── disp.png ├── error3px.png ├── individualImage (1).png ├── left.png ├── loss.png ├── model.png └── virusalize01.png ├── requirements.txt ├── scripts └── train.sh ├── test.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | *.pyc 3 | 4 | checkpoint/ 5 | 6 | log/ 7 | 8 | test.py 9 | 10 | result/ 11 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Check Deng 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PSM-Net 2 | 3 | PyTorch reimplementation of PSM-Net from the paper "[Pyramid Stereo Matching Network](https://arxiv.org/abs/1803.08669)" (CVPR 2018) by Jia-Ren Chang and Yong-Sheng Chen. 4 | 5 | Official repository: [JiaRenChang/PSMNet](https://github.com/JiaRenChang/PSMNet) 6 | 7 | ![model](pic/model.png) 8 | 9 | ## Usage 10 | 11 | ### 1) Requirements 12 | 13 | - Python 3.5+ 14 | - PyTorch 0.4 15 | - OpenCV-Python 16 | - Matplotlib 17 | - TensorboardX 18 | - TensorBoard 19 | 20 | All dependencies are listed in `requirements.txt`; run the command below to install them.
21 | 22 | ``` shell 23 | pip install -r requirements.txt 24 | ``` 25 | 26 | 27 | 28 | ### 2) Train 29 | 30 | ``` shell 31 | usage: train.py [-h] [--maxdisp MAXDISP] [--logdir LOGDIR] [--datadir DATADIR] 32 | [--cuda CUDA] [--batch-size BATCH_SIZE] 33 | [--validate-batch-size VALIDATE_BATCH_SIZE] 34 | [--log-per-step LOG_PER_STEP] 35 | [--save-per-epoch SAVE_PER_EPOCH] [--model-dir MODEL_DIR] 36 | [--lr LR] [--num-epochs NUM_EPOCHS] 37 | [--num-workers NUM_WORKERS] 38 | 39 | PSMNet 40 | 41 | optional arguments: 42 | -h, --help show this help message and exit 43 | --maxdisp MAXDISP max disparity 44 | --logdir LOGDIR log directory 45 | --datadir DATADIR data directory 46 | --cuda CUDA gpu number 47 | --batch-size BATCH_SIZE 48 | batch size 49 | --validate-batch-size VALIDATE_BATCH_SIZE 50 | validation batch size 51 | --log-per-step LOG_PER_STEP 52 | log per step 53 | --save-per-epoch SAVE_PER_EPOCH 54 | save model per epoch 55 | --model-dir MODEL_DIR 56 | directory where model checkpoints are saved 57 | --lr LR learning rate 58 | --num-epochs NUM_EPOCHS 59 | number of training epochs 60 | --num-workers NUM_WORKERS 61 | number of workers for loading data 62 | ``` 63 | 64 | For example: 65 | 66 | ``` shell 67 | python train.py --batch-size 16 \ 68 | --logdir log/example \ 69 | --num-epochs 500 70 | ``` 71 | 72 | 73 | 74 | ### 3) Visualize result 75 | 76 | This repository uses tensorboardX to visualize training results. Point TensorBoard at your log directory to browse them; the default log directory is `log/`. 77 | 78 | ``` shell 79 | tensorboard --logdir log 80 | ``` 81 | 82 | Here are some of my training results (trained for 1000 epochs on KITTI2015): 83 | 84 | ![disp](pic/01.png) 85 | 86 | ![left](pic/02.png) 87 | 88 | ![loss](pic/loss.png) 89 | 90 | ![error](pic/error3px.png) 91 | 92 | 93 | 94 | ### 4) Inference 95 | 96 | ``` shell 97 | usage: inference.py [-h] [--maxdisp MAXDISP] [--left LEFT] [--right RIGHT] 98 | [--model-path MODEL_PATH] [--save-path SAVE_PATH] 99 | 100 | PSMNet inference 101 | 102 | optional arguments: 103 | -h, --help show this help message and exit 104 | --maxdisp MAXDISP max disparity 105 | --left LEFT path to the left image 106 | --right RIGHT path to the right image 107 | --model-path MODEL_PATH 108 | path to the model 109 | --save-path SAVE_PATH 110 | path to save the disp image 111 | ``` 112 | 113 | For example: 114 | 115 | ``` shell 116 | python inference.py --left test/left.png \ 117 | --right test/right.png \ 118 | --model-path checkpoint/08/best_model.ckpt \ 119 | --save-path test/disp.png 120 | ``` 121 | 122 | 123 | 124 | ### 5) Pretrained model 125 | 126 | A model trained for 1000 epochs on the [KITTI2015](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=stereo) dataset can be downloaded [here](https://drive.google.com/open?id=1JW330o2UGQi6XGB4o3pD_MdGttYwiZdv). (I chose the checkpoint with the best validation error among the 1000 epochs; a minimal loading sketch is given at the end of this README.) 127 | 128 | ``` 129 | state { 130 | 'epoch': 857, 131 | '3px-error': 3.466 132 | } 133 | ``` 134 | 135 | ## Task List 136 | 137 | - [x] Train 138 | - [x] Inference 139 | - [x] KITTI2015 dataset 140 | - [ ] Scene Flow dataset 141 | - [x] Visualize 142 | - [x] Pretrained model 143 | 144 | ## Contact 145 | 146 | Email: checkdeng0903@gmail.com 147 | 148 | Discussions are welcome!
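## Appendix: loading a checkpoint

Checkpoints saved by `train.py` store the weights of an `nn.DataParallel`-wrapped model together with the optimizer state, epoch, step and 3px-error. A minimal loading sketch (the checkpoint path is a placeholder; `inference.py` strips the `module.` prefix in the same way when running on a single device):

``` python
import torch
from models.PSMnet import PSMNet

state = torch.load('checkpoint/best_model.ckpt', map_location='cpu')  # placeholder path

model = PSMNet(max_disp=192)
# train.py wraps the model in nn.DataParallel, so saved keys carry a 'module.' prefix
state_dict = {k[len('module.'):] if k.startswith('module.') else k: v
              for k, v in state['state_dict'].items()}
model.load_state_dict(state_dict)
model.eval()

print('epoch: {}, 3px-error: {}%'.format(state['epoch'], state['error']))
```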
149 | 150 | -------------------------------------------------------------------------------- /dataloader/KITTI2015_loader.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import torch 3 | import torch.nn.functional as F 4 | from os.path import join 5 | import cv2 6 | import numpy as np 7 | # from PIL import Image 8 | 9 | 10 | class KITTI2015(Dataset): 11 | 12 | def __init__(self, directory, mode, validate_size=40, occ=True, transform=None): 13 | super().__init__() 14 | 15 | self.mode = mode 16 | self.transform = transform 17 | 18 | if mode == 'train' or mode == 'validate': 19 | self.dir = join(directory, 'training') 20 | elif mode == 'test': 21 | self.dir = join(directory, 'testing') 22 | 23 | left_dir = join(self.dir, 'image_2') 24 | right_dir = join(self.dir, 'image_3') 25 | left_imgs = list() 26 | right_imgs = list() 27 | 28 | if mode == 'train': 29 | imgs_range = range(200 - validate_size) 30 | elif mode == 'validate': 31 | imgs_range = range(200 - validate_size, 200) 32 | elif mode == 'test': 33 | imgs_range = range(200) 34 | 35 | fmt = '{:06}_10.png' 36 | 37 | for i in imgs_range: 38 | left_imgs.append(join(left_dir, fmt.format(i))) 39 | right_imgs.append(join(right_dir, fmt.format(i))) 40 | 41 | self.left_imgs = left_imgs 42 | self.right_imgs = right_imgs 43 | 44 | # self.disp_imgs = None 45 | if mode == 'train' or mode == 'validate': 46 | disp_imgs = list() 47 | if occ: 48 | disp_dir = join(self.dir, 'disp_occ_0') 49 | else: 50 | disp_dir = join(self.dir, 'disp_noc_0') 51 | disp_fmt = '{:06}_10.png' 52 | for i in imgs_range: 53 | disp_imgs.append(join(disp_dir, disp_fmt.format(i))) 54 | 55 | self.disp_imgs = disp_imgs 56 | 57 | def __len__(self): 58 | return len(self.left_imgs) 59 | 60 | def __getitem__(self, idx): 61 | data = {} 62 | 63 | # bgr mode 64 | data['left'] = cv2.imread(self.left_imgs[idx]) 65 | data['right'] = cv2.imread(self.right_imgs[idx]) 66 | if self.mode != 'test': 67 | data['disp'] = cv2.imread(self.disp_imgs[idx])[:, :, 0] 68 | 69 | if self.transform: 70 | data = self.transform(data) 71 | 72 | return data 73 | 74 | 75 | class RandomCrop(): 76 | 77 | def __init__(self, output_size): 78 | self.output_size = output_size 79 | 80 | def __call__(self, sample): 81 | new_h, new_w = self.output_size 82 | h, w, _ = sample['left'].shape 83 | top = np.random.randint(0, h - new_h) 84 | left = np.random.randint(0, w - new_w) 85 | 86 | for key in sample: 87 | sample[key] = sample[key][top: top + new_h, left: left + new_w] 88 | 89 | return sample 90 | 91 | 92 | class Normalize(): 93 | ''' 94 | BGR mode (images are read with cv2, so mean/std are given in BGR order) 95 | ''' 96 | 97 | def __init__(self, mean, std): 98 | self.mean = mean 99 | self.std = std 100 | 101 | def __call__(self, sample): 102 | sample['left'] = sample['left'] / 255.0 103 | sample['right'] = sample['right'] / 255.0 104 | 105 | sample['left'] = self.__normalize(sample['left']) 106 | sample['right'] = self.__normalize(sample['right']) 107 | 108 | return sample 109 | 110 | def __normalize(self, img): 111 | for i in range(3): 112 | img[:, :, i] = (img[:, :, i] - self.mean[i]) / self.std[i] 113 | return img 114 | 115 | 116 | class ToTensor(): 117 | 118 | def __call__(self, sample): 119 | left = sample['left'] 120 | right = sample['right'] 121 | 122 | # H x W x C ---> C x H x W 123 | sample['left'] = torch.from_numpy(left.transpose([2, 0, 1])).type(torch.FloatTensor) 124 | sample['right'] = torch.from_numpy(right.transpose([2, 0, 1])).type(torch.FloatTensor) 125 | 126 | if 'disp' in sample: 127
| sample['disp'] = torch.from_numpy(sample['disp']).type(torch.FloatTensor) 128 | 129 | return sample 130 | 131 | 132 | class Pad(): 133 | def __init__(self, H, W): 134 | self.w = W 135 | self.h = H 136 | 137 | def __call__(self, sample): 138 | pad_h = self.h - sample['left'].size(1) 139 | pad_w = self.w - sample['left'].size(2) 140 | 141 | left = sample['left'].unsqueeze(0) # [1, 3, H, W] 142 | left = F.pad(left, pad=(0, pad_w, 0, pad_h)) 143 | right = sample['right'].unsqueeze(0) # [1, 3, H, W] 144 | right = F.pad(right, pad=(0, pad_w, 0, pad_h)) 145 | disp = sample['disp'].unsqueeze(0).unsqueeze(1) # [1, 1, H, W] 146 | disp = F.pad(disp, pad=(0, pad_w, 0, pad_h)) 147 | 148 | sample['left'] = left.squeeze() 149 | sample['right'] = right.squeeze() 150 | sample['disp'] = disp.squeeze() 151 | 152 | return sample 153 | 154 | 155 | if __name__ == '__main__': 156 | import torchvision.transforms as T 157 | import matplotlib.pyplot as plt 158 | from torch.utils.data import DataLoader 159 | # BGR 160 | mean = [0.406, 0.456, 0.485] 161 | std = [0.225, 0.224, 0.229] 162 | 163 | train_transform = T.Compose([RandomCrop([256, 512]), ToTensor()]) 164 | train_dataset = KITTI2015('D:/dataset/data_scene_flow', mode='train', transform=train_transform) 165 | train_loader = DataLoader(train_dataset) 166 | print(len(train_loader)) 167 | 168 | # test_transform = T.Compose([ToTensor()]) 169 | # test_dataset = KITTI2015('D:/dataset/data_scene_flow', mode='test', transform=test_transform) 170 | 171 | # validate_transform = T.Compose([ToTensor()]) 172 | # validate_dataset = KITTI2015('D:/dataset/data_scene_flow', mode='validate', transform=validate_transform) 173 | 174 | # datasets = [train_dataset, test_dataset, validate_dataset] 175 | 176 | # for i, dataset in enumerate(datasets): 177 | # a = dataset[0]['right'].numpy().transpose([1, 2, 0]) 178 | # plt.subplot(3, 1, i + 1) 179 | # plt.imshow(a) 180 | # plt.show() 181 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import cv2 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torchvision.transforms as T 7 | from models.PSMnet import PSMNet 8 | from dataloader.KITTI2015_loader import ToTensor, Normalize 9 | import torch.nn.functional as F 10 | 11 | import matplotlib 12 | matplotlib.use('agg') 13 | import matplotlib.pyplot as plt 14 | 15 | 16 | parser = argparse.ArgumentParser(description='PSMNet inference') 17 | parser.add_argument('--maxdisp', type=int, default=192, help='max disparity') 18 | parser.add_argument('--left', default=None, help='path to the left image') 19 | parser.add_argument('--right', default=None, help='path to the right image') 20 | parser.add_argument('--model-path', default=None, help='path to the model') 21 | parser.add_argument('--save-path', default=None, help='path to save the disp image') 22 | args = parser.parse_args() 23 | 24 | 25 | mean = [0.406, 0.456, 0.485] 26 | std = [0.225, 0.224, 0.229] 27 | device_ids = [0, 1, 2, 3] 28 | device = torch.device('cuda:{}'.format(device_ids[0])) 29 | 30 | 31 | def main(): 32 | left = cv2.imread(args.left) 33 | right = cv2.imread(args.right) 34 | 35 | pairs = {'left': left, 'right': right} 36 | 37 | transform = T.Compose([Normalize(mean, std), ToTensor(), Pad(384, 1248)]) 38 | pairs = transform(pairs) 39 | left = pairs['left'].to(device).unsqueeze(0) 40 | right = pairs['right'].to(device).unsqueeze(0) 41 | 42 | model =
PSMNet(args.maxdisp).to(device) 43 | if len(device_ids) > 1: 44 | model = nn.DataParallel(model, device_ids=device_ids) 45 | 46 | state = torch.load(args.model_path) 47 | if len(device_ids) == 1: 48 | from collections import OrderedDict 49 | new_state_dict = OrderedDict() 50 | for k, v in state['state_dict'].items(): 51 | namekey = k[7:] # remove `module.` 52 | new_state_dict[namekey] = v 53 | state['state_dict'] = new_state_dict 54 | 55 | model.load_state_dict(state['state_dict']) 56 | print('loaded model from {}'.format(args.model_path)) 57 | print('epoch: {}'.format(state['epoch'])) 58 | print('3px-error: {}%'.format(state['error'])) 59 | 60 | model.eval() 61 | 62 | with torch.no_grad(): 63 | _, _, disp = model(left, right) 64 | 65 | disp = disp.squeeze(0).detach().cpu().numpy() 66 | plt.figure(figsize=(12.84, 3.84)) 67 | plt.axis('off') 68 | plt.imshow(disp) 69 | plt.colorbar() 70 | plt.savefig(args.save_path, dpi=100) 71 | 72 | print('saved disparity map to {}'.format(args.save_path)) 73 | 74 | 75 | class Pad(): 76 | def __init__(self, H, W): 77 | self.w = W 78 | self.h = H 79 | 80 | def __call__(self, sample): 81 | pad_h = self.h - sample['left'].size(1) 82 | pad_w = self.w - sample['left'].size(2) 83 | 84 | left = sample['left'].unsqueeze(0) # [1, 3, H, W] 85 | left = F.pad(left, pad=(0, pad_w, 0, pad_h)) 86 | right = sample['right'].unsqueeze(0) # [1, 3, H, W] 87 | right = F.pad(right, pad=(0, pad_w, 0, pad_h)) 88 | # disp = sample['disp'].unsqueeze(0).unsqueeze(1) # [1, 1, H, W] 89 | # disp = F.pad(disp, pad=(0, pad_w, 0, pad_h)) 90 | 91 | sample['left'] = left.squeeze() 92 | sample['right'] = right.squeeze() 93 | # sample['disp'] = disp.squeeze() 94 | 95 | return sample 96 | 97 | 98 | if __name__ == '__main__': 99 | main() 100 | -------------------------------------------------------------------------------- /models/PSMnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from models.costnet import CostNet 5 | from models.stackedhourglass import StackedHourglass 6 | 7 | 8 | class PSMNet(nn.Module): 9 | 10 | def __init__(self, max_disp): 11 | super().__init__() 12 | 13 | self.cost_net = CostNet() 14 | self.stackedhourglass = StackedHourglass(max_disp) 15 | self.D = max_disp 16 | 17 | self.__init_params() 18 | 19 | def forward(self, left_img, right_img): 20 | original_size = [self.D, left_img.size(2), left_img.size(3)] 21 | 22 | left_cost = self.cost_net(left_img) # [B, 32, 1/4H, 1/4W] 23 | right_cost = self.cost_net(right_img) # [B, 32, 1/4H, 1/4W] 24 | # cost = torch.cat([left_cost, right_cost], dim=1) # [B, 64, 1/4H, 1/4W] 25 | # B, C, H, W = cost.size() 26 | 27 | # print('left_cost') 28 | # print(left_cost[0, 0, :3, :3]) 29 | 30 | B, C, H, W = left_cost.size() 31 | 32 | cost_volume = torch.zeros(B, C * 2, self.D // 4, H, W).type_as(left_cost) # [B, 64, 1/4D, 1/4H, 1/4W] 33 | 34 | # for i in range(self.D // 4): 35 | # cost_volume[:, :, i, :, i:] = cost[:, :, :, i:] 36 | 37 | for i in range(self.D // 4): 38 | if i > 0: 39 | cost_volume[:, :C, i, :, i:] = left_cost[:, :, :, i:] 40 | cost_volume[:, C:, i, :, i:] = right_cost[:, :, :, :-i] 41 | else: 42 | cost_volume[:, :C, i, :, :] = left_cost 43 | cost_volume[:, C:, i, :, :] = right_cost 44 | 45 | disp1, disp2, disp3 = self.stackedhourglass(cost_volume, out_size=original_size) 46 | 47 | return disp1, disp2, disp3 48 | 49 | def __init_params(self): 50 | for m in self.modules(): 51 | if isinstance(m, nn.Conv2d): 52 | n = m.kernel_size[0] *
m.kernel_size[1] * m.out_channels 53 | m.weight.data.normal_(0, math.sqrt(2. / n)) 54 | elif isinstance(m, nn.Conv3d): 55 | n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels 56 | m.weight.data.normal_(0, math.sqrt(2. / n)) 57 | elif isinstance(m, nn.BatchNorm2d): 58 | m.weight.data.fill_(1) 59 | m.bias.data.zero_() 60 | elif isinstance(m, nn.BatchNorm3d): 61 | m.weight.data.fill_(1) 62 | m.bias.data.zero_() 63 | elif isinstance(m, nn.Linear): 64 | m.bias.data.zero_() 65 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KinglittleQ/PSMNet/897c45facaa176e70ab43cb6aefd5427da1f1b2a/models/__init__.py -------------------------------------------------------------------------------- /models/costnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class CostNet(nn.Module): 7 | 8 | def __init__(self): 9 | super().__init__() 10 | 11 | self.cnn = CNN() 12 | self.spp = SPP() 13 | self.fusion = nn.Sequential( 14 | Conv2dBn(in_channels=320, out_channels=128, kernel_size=3, stride=1, padding=1, use_relu=True), 15 | nn.Conv2d(in_channels=128, out_channels=32, kernel_size=1, stride=1, padding=0, bias=False) 16 | ) 17 | 18 | def forward(self, inputs): 19 | conv2_out, conv4_out = self.cnn(inputs) # [B, 64, 1/4H, 1/4W], [B, 128, 1/4H, 1/4W] 20 | 21 | spp_out = self.spp(conv4_out) # [B, 128, 1/4H, 1/4W] 22 | out = torch.cat([conv2_out, conv4_out, spp_out], dim=1) # [B, 320, 1/4H, 1/4W] 23 | out = self.fusion(out) # [B, 32, 1/4H, 1/4W] 24 | 25 | return out 26 | 27 | 28 | class SPP(nn.Module): 29 | 30 | def __init__(self): 31 | super().__init__() 32 | 33 | self.branch1 = self.__make_branch(kernel_size=64, stride=64) 34 | self.branch2 = self.__make_branch(kernel_size=32, stride=32) 35 | self.branch3 = self.__make_branch(kernel_size=16, stride=16) 36 | self.branch4 = self.__make_branch(kernel_size=8, stride=8) 37 | 38 | def forward(self, inputs): 39 | 40 | out_size = inputs.size(2), inputs.size(3) 41 | branch1_out = F.upsample(self.branch1(inputs), size=out_size, mode='bilinear') # [B, 32, 1/4H, 1/4W] 42 | # print('branch1_out') 43 | # print(branch1_out[0, 0, :3, :3]) 44 | branch2_out = F.upsample(self.branch2(inputs), size=out_size, mode='bilinear') # [B, 32, 1/4H, 1/4W] 45 | branch3_out = F.upsample(self.branch3(inputs), size=out_size, mode='bilinear') # [B, 32, 1/4H, 1/4W] 46 | branch4_out = F.upsample(self.branch4(inputs), size=out_size, mode='bilinear') # [B, 32, 1/4H, 1/4W] 47 | out = torch.cat([branch4_out, branch3_out, branch2_out, branch1_out], dim=1) # [B, 128, 1/4H, 1/4W] 48 | 49 | return out 50 | 51 | @staticmethod 52 | def __make_branch(kernel_size, stride): 53 | branch = nn.Sequential( 54 | nn.AvgPool2d(kernel_size, stride), 55 | Conv2dBn(in_channels=128, out_channels=32, kernel_size=3, stride=1, padding=1, use_relu=True) # kernel size may be 1 in the official implementation 56 | ) 57 | return branch 58 | 59 | 60 | class CNN(nn.Module): 61 | 62 | def __init__(self): 63 | super().__init__() 64 | 65 | self.conv0 = nn.Sequential( 66 | Conv2dBn(in_channels=3, out_channels=32, kernel_size=3, stride=2, padding=1, use_relu=True), # downsample 67 | Conv2dBn(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1, use_relu=True), 68 | Conv2dBn(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1,
use_relu=True) 69 | ) 70 | 71 | self.conv1 = StackedBlocks(n_blocks=3, in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1, dilation=1) 72 | self.conv2 = StackedBlocks(n_blocks=16, in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1, dilation=1) # downsample 73 | self.conv3 = StackedBlocks(n_blocks=3, in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=2, dilation=2) # dilated 74 | self.conv4 = StackedBlocks(n_blocks=3, in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=4, dilation=4) # dilated 75 | 76 | def forward(self, inputs): 77 | conv0_out = self.conv0(inputs) 78 | conv1_out = self.conv1(conv0_out) # [B, 32, 1/2H, 1/2W] 79 | conv2_out = self.conv2(conv1_out) # [B, 64, 1/4H, 1/4W] 80 | conv3_out = self.conv3(conv2_out) # [B, 128, 1/4H, 1/4W] 81 | conv4_out = self.conv4(conv3_out) # [B, 128, 1/4H, 1/4W] 82 | 83 | return conv2_out, conv4_out 84 | 85 | 86 | class StackedBlocks(nn.Module): 87 | 88 | def __init__(self, n_blocks, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1): 89 | super().__init__() 90 | 91 | if stride == 1 and in_channels == out_channels: 92 | downsample = False 93 | else: 94 | downsample = True 95 | net = [ResidualBlock(in_channels, out_channels, kernel_size, stride, padding, dilation, downsample)] 96 | 97 | for i in range(n_blocks - 1): 98 | net.append(ResidualBlock(out_channels, out_channels, kernel_size, 1, padding, dilation, downsample=False)) 99 | self.net = nn.Sequential(*net) 100 | 101 | def forward(self, inputs): 102 | out = self.net(inputs) 103 | return out 104 | 105 | 106 | class ResidualBlock(nn.Module): 107 | 108 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, downsample=False): 109 | super().__init__() 110 | 111 | self.net = nn.Sequential( 112 | Conv2dBn(in_channels, out_channels, kernel_size, stride, padding, dilation, use_relu=True), 113 | Conv2dBn(out_channels, out_channels, kernel_size, 1, padding, dilation, use_relu=False) 114 | ) 115 | 116 | self.downsample = None 117 | if downsample: 118 | self.downsample = Conv2dBn(in_channels, out_channels, 1, stride, use_relu=False) 119 | 120 | def forward(self, inputs): 121 | out = self.net(inputs) 122 | if self.downsample: 123 | inputs = self.downsample(inputs) 124 | out = out + inputs 125 | 126 | return out 127 | 128 | 129 | class Conv2dBn(nn.Module): 130 | 131 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, use_relu=True): 132 | super().__init__() 133 | 134 | net = [nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=False), 135 | nn.BatchNorm2d(out_channels)] 136 | if use_relu: 137 | net.append(nn.ReLU(inplace=True)) 138 | self.net = nn.Sequential(*net) 139 | 140 | def forward(self, inputs): 141 | out = self.net(inputs) 142 | return out -------------------------------------------------------------------------------- /models/smoothloss.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import torch.nn as nn 3 | 4 | 5 | class SmoothL1Loss(nn.Module): 6 | 7 | def __init__(self): 8 | super().__init__() 9 | 10 | def forward(self, disp1, disp2, disp3, target): 11 | loss1 = F.smooth_l1_loss(disp1, target) 12 | loss2 = F.smooth_l1_loss(disp2, target) 13 | loss3 = F.smooth_l1_loss(disp3, target) 14 | 15 | return loss1, loss2, loss3 16 | -------------------------------------------------------------------------------- 
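`SmoothL1Loss` returns one smooth-L1 loss per intermediate disparity map, and `train.py` combines the three with the 0.5/0.7/1.0 weighting from the paper. A minimal usage sketch with toy tensors (shapes and values are illustrative only):

``` python
import torch
from models.smoothloss import SmoothL1Loss

criterion = SmoothL1Loss()
disp1, disp2, disp3 = (torch.rand(2, 256, 512) * 192 for _ in range(3))  # toy predictions [B, H, W]
target = torch.rand(2, 256, 512) * 192                                   # toy ground truth [B, H, W]
mask = target > 0  # KITTI stores 0 where no ground-truth disparity exists
loss1, loss2, loss3 = criterion(disp1[mask], disp2[mask], disp3[mask], target[mask])
total_loss = 0.5 * loss1 + 0.7 * loss2 + 1.0 * loss3  # weights used in train.py
```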
/models/stackedhourglass.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class StackedHourglass(nn.Module): 7 | ''' 8 | inputs --- [B, 64, 1/4D, 1/4H, 1/4W] 9 | ''' 10 | 11 | def __init__(self, max_disp): 12 | super().__init__() 13 | 14 | self.conv0 = nn.Sequential( 15 | Conv3dBn(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1, dilation=1, use_relu=True), 16 | Conv3dBn(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1, dilation=1, use_relu=True) 17 | ) 18 | self.conv1 = nn.Sequential( 19 | Conv3dBn(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1, dilation=1, use_relu=True), 20 | Conv3dBn(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1, dilation=1, use_relu=False) 21 | ) 22 | self.hourglass1 = Hourglass() 23 | self.hourglass2 = Hourglass() 24 | self.hourglass3 = Hourglass() 25 | 26 | self.out1 = nn.Sequential( 27 | Conv3dBn(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1, dilation=1, use_relu=True), 28 | nn.Conv3d(in_channels=32, out_channels=1, kernel_size=3, stride=1, padding=1, dilation=1, bias=False) 29 | ) 30 | self.out2 = nn.Sequential( 31 | Conv3dBn(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1, dilation=1, use_relu=True), 32 | nn.Conv3d(in_channels=32, out_channels=1, kernel_size=3, stride=1, padding=1, dilation=1, bias=False) 33 | ) 34 | self.out3 = nn.Sequential( 35 | Conv3dBn(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1, dilation=1, use_relu=True), 36 | nn.Conv3d(in_channels=32, out_channels=1, kernel_size=3, stride=1, padding=1, dilation=1, bias=False) 37 | ) 38 | 39 | self.regression = DisparityRegression(max_disp) 40 | 41 | def forward(self, inputs, out_size): 42 | 43 | conv0_out = self.conv0(inputs) # [B, 32, 1/4D, 1/4H, 1/4W] 44 | conv1_out = self.conv1(conv0_out) 45 | conv1_out = conv0_out + conv1_out # [B, 32, 1/4D, 1/4H, 1/4W] 46 | 47 | hourglass1_out1, hourglass1_out3, hourglass1_out4 = self.hourglass1(conv1_out, scale1=None, scale2=None, scale3=conv1_out) 48 | hourglass2_out1, hourglass2_out3, hourglass2_out4 = self.hourglass2(hourglass1_out4, scale1=hourglass1_out3, scale2=hourglass1_out1, scale3=conv1_out) 49 | hourglass3_out1, hourglass3_out3, hourglass3_out4 = self.hourglass3(hourglass2_out4, scale1=hourglass2_out3, scale2=hourglass1_out1, scale3=conv1_out) 50 | 51 | out1 = self.out1(hourglass1_out4) # [B, 1, 1/4D, 1/4H, 1/4W] 52 | out2 = self.out2(hourglass2_out4) + out1 53 | out3 = self.out3(hourglass3_out4) + out2 54 | 55 | cost1 = F.upsample(out1, size=out_size, mode='trilinear').squeeze(dim=1) # [B, D, H, W] 56 | cost2 = F.upsample(out2, size=out_size, mode='trilinear').squeeze(dim=1) # [B, D, H, W] 57 | cost3 = F.upsample(out3, size=out_size, mode='trilinear').squeeze(dim=1) # [B, D, H, W] 58 | 59 | prob1 = F.softmax(-cost1, dim=1) # [B, D, H, W] 60 | prob2 = F.softmax(-cost2, dim=1) 61 | prob3 = F.softmax(-cost3, dim=1) 62 | 63 | disp1 = self.regression(prob1) 64 | disp2 = self.regression(prob2) 65 | disp3 = self.regression(prob3) 66 | 67 | return disp1, disp2, disp3 68 | 69 | 70 | class DisparityRegression(nn.Module): 71 | 72 | def __init__(self, max_disp): 73 | super().__init__() 74 | 75 | self.disp_score = torch.arange(0, max_disp).float() # [D]; torch.range is deprecated 76 | self.disp_score = self.disp_score.unsqueeze(0).unsqueeze(2).unsqueeze(3) # [1, D, 1, 1] 77 | 78 | def forward(self, prob): 79 | disp_score =
self.disp_score.expand_as(prob).type_as(prob) # [B, D, H, W] 80 | out = torch.sum(disp_score * prob, dim=1) # [B, H, W] 81 | return out 82 | 83 | 84 | class Hourglass(nn.Module): 85 | 86 | def __init__(self): 87 | super().__init__() 88 | 89 | self.net1 = nn.Sequential( 90 | Conv3dBn(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1, dilation=1, use_relu=True), 91 | Conv3dBn(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, dilation=1, use_relu=False) 92 | ) 93 | self.net2 = nn.Sequential( 94 | Conv3dBn(in_channels=64, out_channels=64, kernel_size=3, stride=2, padding=1, dilation=1, use_relu=True), 95 | Conv3dBn(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, dilation=1, use_relu=True) 96 | ) 97 | self.net3 = nn.Sequential( 98 | nn.ConvTranspose3d(in_channels=64, out_channels=64, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False), 99 | nn.BatchNorm3d(num_features=64) 100 | # nn.ReLU(inplace=True) 101 | ) 102 | self.net4 = nn.Sequential( 103 | nn.ConvTranspose3d(in_channels=64, out_channels=32, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False), 104 | nn.BatchNorm3d(num_features=32) 105 | ) 106 | 107 | def forward(self, inputs, scale1=None, scale2=None, scale3=None): 108 | net1_out = self.net1(inputs) # [B, 64, 1/8D, 1/8H, 1/8W] 109 | 110 | if scale1 is not None: 111 | net1_out = F.relu(net1_out + scale1, inplace=True) 112 | else: 113 | net1_out = F.relu(net1_out, inplace=True) 114 | 115 | net2_out = self.net2(net1_out) # [B, 64, 1/16D, 1/16H, 1/16W] 116 | net3_out = self.net3(net2_out) # [B, 64, 1/8D, 1/8H, 1/8W] 117 | 118 | if scale2 is not None: 119 | net3_out = F.relu(net3_out + scale2, inplace=True) 120 | else: 121 | net3_out = F.relu(net3_out + net1_out, inplace=True) 122 | 123 | net4_out = self.net4(net3_out) 124 | 125 | if scale3 is not None: 126 | net4_out = net4_out + scale3 127 | 128 | return net1_out, net3_out, net4_out 129 | 130 | 131 | class Conv3dBn(nn.Module): 132 | 133 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, use_relu=True): 134 | super().__init__() 135 | 136 | net = [nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=False), 137 | nn.BatchNorm3d(out_channels)] 138 | if use_relu: 139 | net.append(nn.ReLU(inplace=True)) 140 | 141 | self.net = nn.Sequential(*net) 142 | 143 | def forward(self, inputs): 144 | out = self.net(inputs) 145 | return out 146 | -------------------------------------------------------------------------------- /pic/01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KinglittleQ/PSMNet/897c45facaa176e70ab43cb6aefd5427da1f1b2a/pic/01.png -------------------------------------------------------------------------------- /pic/02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KinglittleQ/PSMNet/897c45facaa176e70ab43cb6aefd5427da1f1b2a/pic/02.png -------------------------------------------------------------------------------- /pic/03.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KinglittleQ/PSMNet/897c45facaa176e70ab43cb6aefd5427da1f1b2a/pic/03.png -------------------------------------------------------------------------------- /pic/04.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/KinglittleQ/PSMNet/897c45facaa176e70ab43cb6aefd5427da1f1b2a/pic/04.png -------------------------------------------------------------------------------- /pic/disp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KinglittleQ/PSMNet/897c45facaa176e70ab43cb6aefd5427da1f1b2a/pic/disp.png -------------------------------------------------------------------------------- /pic/error3px.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KinglittleQ/PSMNet/897c45facaa176e70ab43cb6aefd5427da1f1b2a/pic/error3px.png -------------------------------------------------------------------------------- /pic/individualImage (1).png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KinglittleQ/PSMNet/897c45facaa176e70ab43cb6aefd5427da1f1b2a/pic/individualImage (1).png -------------------------------------------------------------------------------- /pic/left.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KinglittleQ/PSMNet/897c45facaa176e70ab43cb6aefd5427da1f1b2a/pic/left.png -------------------------------------------------------------------------------- /pic/loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KinglittleQ/PSMNet/897c45facaa176e70ab43cb6aefd5427da1f1b2a/pic/loss.png -------------------------------------------------------------------------------- /pic/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KinglittleQ/PSMNet/897c45facaa176e70ab43cb6aefd5427da1f1b2a/pic/model.png -------------------------------------------------------------------------------- /pic/virusalize01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KinglittleQ/PSMNet/897c45facaa176e70ab43cb6aefd5427da1f1b2a/pic/virusalize01.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | opencv_contrib_python_headless==3.4.3.18 2 | torchvision==0.2.1 3 | torch==0.4.0 4 | matplotlib==2.2.2 5 | numpy==1.14.3 6 | tensorboardX==1.4 7 | -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | python train.py --logdir log/exp01 \ 2 | --num-epochs 2 \ 3 | --batch-size 4 4 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models.PSMnet import PSMNet 3 | 4 | torch.manual_seed(2.0) 5 | 6 | model = PSMNet(16).cuda() 7 | left = torch.randn(2, 3, 256, 256).cuda() 8 | right = torch.randn(2, 3, 256, 256).cuda() 9 | print(left[:, :, 0, 0]) 10 | 11 | out1, out2, out3 = model(left, right) 12 | print(out2[0, :3, :3]) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | import numpy as np 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim 
as optim 9 | import torchvision.transforms as T 10 | from torch.utils.data import DataLoader 11 | from models.PSMnet import PSMNet 12 | from models.smoothloss import SmoothL1Loss 13 | from dataloader.KITTI2015_loader import KITTI2015, RandomCrop, ToTensor, Normalize, Pad 14 | 15 | import tensorboardX as tX 16 | 17 | import matplotlib 18 | matplotlib.use('agg') 19 | import matplotlib.pyplot as plt 20 | 21 | 22 | parser = argparse.ArgumentParser(description='PSMNet') 23 | parser.add_argument('--maxdisp', type=int, default=192, help='max disparity') 24 | parser.add_argument('--logdir', default='log/runs', help='log directory') 25 | parser.add_argument('--datadir', default='../../data/KITTI2015', help='data directory') 26 | parser.add_argument('--cuda', type=int, default=0, help='gpu number') 27 | parser.add_argument('--batch-size', type=int, default=8, help='batch size') 28 | parser.add_argument('--validate-batch-size', type=int, default=2, help='validation batch size') 29 | parser.add_argument('--log-per-step', type=int, default=1, help='log per step') 30 | parser.add_argument('--save-per-epoch', type=int, default=1, help='save model per epoch') 31 | parser.add_argument('--model-dir', default='checkpoint', help='directory where model checkpoints are saved') 32 | parser.add_argument('--model-path', default=None, help='path of model to load') 33 | # parser.add_argument('--start-step', type=int, default=0, help='number of steps at starting') 34 | parser.add_argument('--lr', type=float, default=0.001, help='learning rate') 35 | parser.add_argument('--num-epochs', type=int, default=300, help='number of training epochs') 36 | parser.add_argument('--num-workers', type=int, default=8, help='number of workers for loading data') 37 | 38 | 39 | args = parser.parse_args() 40 | 41 | 42 | # imagenet mean/std in BGR order 43 | mean = [0.406, 0.456, 0.485] 44 | std = [0.225, 0.224, 0.229] 45 | device_ids = [0, 1, 2, 3] 46 | 47 | writer = tX.SummaryWriter(log_dir=args.logdir, comment='PSMNet') 48 | device = torch.device('cuda') 49 | print(device) 50 | 51 | 52 | def main(args): 53 | 54 | train_transform = T.Compose([RandomCrop([256, 512]), Normalize(mean, std), ToTensor()]) 55 | train_dataset = KITTI2015(args.datadir, mode='train', transform=train_transform) 56 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) 57 | 58 | validate_transform = T.Compose([Normalize(mean, std), ToTensor(), Pad(384, 1248)]) 59 | validate_dataset = KITTI2015(args.datadir, mode='validate', transform=validate_transform) 60 | validate_loader = DataLoader(validate_dataset, batch_size=args.validate_batch_size, num_workers=args.num_workers) 61 | 62 | step = 0 63 | best_error = 100.0 64 | 65 | model = PSMNet(args.maxdisp).to(device) 66 | model = nn.DataParallel(model, device_ids=device_ids) 67 | criterion = SmoothL1Loss().to(device) 68 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 69 | 70 | if args.model_path is not None: 71 | state = torch.load(args.model_path) 72 | model.load_state_dict(state['state_dict']) 73 | optimizer.load_state_dict(state['optimizer']) 74 | step = state['step'] 75 | best_error = state['error'] 76 | print('loaded model from {}'.format(args.model_path)) 77 | 78 | print('Number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()]))) 79 | 80 | for epoch in range(1, args.num_epochs + 1): 81 | model.train() 82 | step = train(model, train_loader, optimizer, criterion, step) 83 | adjust_lr(optimizer, epoch) 84 | 85 | if epoch %
args.save_per_epoch == 0: 86 | model.eval() 87 | error = validate(model, validate_loader, epoch) 88 | best_error = save(model, optimizer, epoch, step, error, best_error) 89 | 90 | 91 | def validate(model, validate_loader, epoch): 92 | ''' 93 | validate 40 image pairs 94 | ''' 95 | num_batches = len(validate_loader) 96 | idx = np.random.randint(num_batches) 97 | 98 | avg_error = 0.0 99 | for i, batch in enumerate(validate_loader): 100 | left_img = batch['left'].to(device) 101 | right_img = batch['right'].to(device) 102 | target_disp = batch['disp'].to(device) 103 | 104 | mask = (target_disp > 0) 105 | mask = mask.detach_() 106 | 107 | with torch.no_grad(): 108 | _, _, disp = model(left_img, right_img) 109 | 110 | delta = torch.abs(disp[mask] - target_disp[mask]) 111 | error_mat = (((delta >= 3.0) + (delta >= 0.05 * (target_disp[mask]))) == 2) # a pixel is wrong if the error is >= 3 px AND >= 5% of ground truth 112 | error = torch.sum(error_mat).item() / torch.numel(disp[mask]) * 100 113 | 114 | avg_error += error 115 | if i == idx: 116 | left_save = left_img 117 | disp_save = disp 118 | 119 | avg_error = avg_error / num_batches 120 | print('epoch: {:03} | 3px-error: {:.5}%'.format(epoch, avg_error)) 121 | writer.add_scalar('error/3px', avg_error, epoch) 122 | save_image(left_save[0], disp_save[0], epoch) 123 | 124 | return avg_error 125 | 126 | 127 | def save_image(left_image, disp, epoch): 128 | for i in range(3): 129 | left_image[i] = left_image[i] * std[i] + mean[i] 130 | b, r = left_image[0], left_image[2] 131 | left_image[0] = r # BGR --> RGB 132 | left_image[2] = b 133 | # left_image = torch.from_numpy(left_image.cpu().numpy()[::-1]) 134 | 135 | disp_img = disp.detach().cpu().numpy() 136 | fig = plt.figure(figsize=(12.84, 3.84)) 137 | plt.axis('off') # hide axis 138 | plt.imshow(disp_img) 139 | plt.colorbar() 140 | 141 | writer.add_figure('image/disp', fig, global_step=epoch) 142 | writer.add_image('image/left', left_image, global_step=epoch) 143 | 144 | 145 | def train(model, train_loader, optimizer, criterion, step): 146 | ''' 147 | train one epoch 148 | ''' 149 | for batch in train_loader: 150 | step += 1 151 | optimizer.zero_grad() 152 | 153 | left_img = batch['left'].to(device) 154 | right_img = batch['right'].to(device) 155 | target_disp = batch['disp'].to(device) 156 | 157 | mask = (target_disp > 0) 158 | mask = mask.detach_() 159 | 160 | disp1, disp2, disp3 = model(left_img, right_img) 161 | loss1, loss2, loss3 = criterion(disp1[mask], disp2[mask], disp3[mask], target_disp[mask]) 162 | total_loss = 0.5 * loss1 + 0.7 * loss2 + 1.0 * loss3 163 | 164 | total_loss.backward() 165 | optimizer.step() 166 | 167 | # print(step) 168 | 169 | if step % args.log_per_step == 0: 170 | writer.add_scalar('loss/loss1', loss1, step) 171 | writer.add_scalar('loss/loss2', loss2, step) 172 | writer.add_scalar('loss/loss3', loss3, step) 173 | writer.add_scalar('loss/total_loss', total_loss, step) 174 | print('step: {:05} | total loss: {:.5} | loss1: {:.5} | loss2: {:.5} | loss3: {:.5}'.format(step, total_loss.item(), loss1.item(), loss2.item(), loss3.item())) 175 | 176 | return step 177 | 178 | 179 | def adjust_lr(optimizer, epoch): 180 | if epoch == 200: 181 | lr = 0.0001 182 | for param_group in optimizer.param_groups: 183 | param_group['lr'] = lr 184 | 185 | 186 | def save(model, optimizer, epoch, step, error, best_error): 187 | path = os.path.join(args.model_dir, '{:03}.ckpt'.format(epoch)) 188 | # torch.save(model.state_dict(), path) 189 | # model.save_state_dict(path) 190 | 191 | state = {} 192 | state['state_dict'] = model.state_dict() 193 | state['optimizer'] =
optimizer.state_dict() 194 | state['error'] = error 195 | state['epoch'] = epoch 196 | state['step'] = step 197 | 198 | torch.save(state, path) 199 | print('saved model at epoch {}'.format(epoch)) 200 | 201 | if error < best_error: 202 | best_error = error 203 | best_path = os.path.join(args.model_dir, 'best_model.ckpt') 204 | shutil.copyfile(path, best_path) 205 | print('best model in epoch {}'.format(epoch)) 206 | 207 | return best_error 208 | 209 | 210 | if __name__ == '__main__': 211 | main(args) 212 | writer.close() 213 | --------------------------------------------------------------------------------
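For reference, the disparity regression in `models/stackedhourglass.py` is a soft-argmin: the upsampled cost volume is negated, softmaxed over the disparity dimension, and reduced to its expected disparity value. A standalone sketch with a toy cost volume (`torch.arange` stands in for the deprecated `torch.range` in `DisparityRegression`):

``` python
import torch
import torch.nn.functional as F

D = 192                                     # max disparity
cost = torch.randn(2, D, 64, 128)           # toy cost volume [B, D, H, W]
prob = F.softmax(-cost, dim=1)              # lower cost -> higher probability
disp_score = torch.arange(0, D, dtype=torch.float32).view(1, D, 1, 1)  # disparity values 0..D-1
disp = torch.sum(disp_score * prob, dim=1)  # expected disparity [B, H, W]
```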