├── figure
    └── figure.png
├── LICENSE
├── dataloader
    ├── readpfm.py
    ├── KITTI2012loader.py
    ├── ETH3D_loader.py
    ├── vKITTI_loader.py
    ├── middlebury_loader.py
    ├── KITTIloader.py
    └── sceneflow_loader.py
├── README.md
├── networks
    ├── feature_extraction.py
    ├── stackhourglass.py
    ├── vgg.py
    ├── submodule.py
    ├── resnet.py
    ├── U_net.py
    └── Aggregator.py
├── test_eth3d.py
├── test_kitti.py
├── test_middlebury.py
├── train_baseline.py
├── train_adaptor.py
├── retrain_CostAggregation.py
└── loss_functions.py
/figure/figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SpadeLiu/Graft-PSMNet/HEAD/figure/figure.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 qqwweee 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /dataloader/readpfm.py: -------------------------------------------------------------------------------- 1 | from __future__ import unicode_literals 2 | import re 3 | import numpy as np 4 | 5 | 6 | def readPFM(file): 7 | file = open(file, 'rb') 8 | 9 | color = None 10 | width = None 11 | height = None 12 | scale = None 13 | endian = None 14 | 15 | header = file.readline().rstrip() 16 | header = header.decode('utf-8') 17 | if header == 'PF': 18 | color = True 19 | elif header == 'Pf': 20 | color = False 21 | else: 22 | raise Exception('Not a PFM file.') 23 | 24 | dim_match = re.match('^(\d+)\s(\d+)\s$', file.readline().decode('utf-8')) 25 | if dim_match: 26 | width, height = map(int, dim_match.groups()) 27 | else: 28 | raise Exception('Malformed PFM header.') 29 | 30 | scale = float(file.readline().rstrip().decode('utf-8')) 31 | if scale < 0: 32 | endian = '<' 33 | scale = -scale 34 | else: 35 | endian = '>' 36 | 37 | data = np.fromfile(file, endian + 'f') 38 | shape = (height, width, 3) if color else (height, width) 39 | 40 | data = np.reshape(data, shape) 41 | data = np.flipud(data) 42 | 43 | return data, scale 44 | 45 | 46 | if __name__ == '__main__': 47 | img_path = \ 48 | '/media/data/dataset/SceneFlow/driving_frames_cleanpass/15mm_focallength/scene_backwards/fast/left/0100.png' 49 | disp_path = img_path.replace('driving_frames_cleanpass', 'driving_disparity').replace('png', 'pfm') 50 | 51 | data, scale = readPFM(disp_path) 52 | dataL = np.ascontiguousarray(data, dtype=np.float32)\ 53 | 54 | import matplotlib.pyplot as plt 55 | plt.imshow(dataL) 56 | plt.show() -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ### GraftNet: Towards Domain Generalized Stereo Matching with a Broad-Spectrum and Task-Oriented Feature 2 | 3 | 4 | 5 | #### Dependencies: 6 | - Python 3.6 7 | - PyTorch 1.7.0 8 | - torchvision 0.3.0 9 | - [VGG trained on ImageNet](https://download.pytorch.org/models/vgg16-397923af.pth) 10 | 11 | #### Datasets: 12 | - [SceneFlow](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html) 13 | - [KITTI stereo 2015](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=stereo) 14 | - [KITTI stereo 2012](http://www.cvlibs.net/datasets/kitti/eval_stereo_flow.php?benchmark=stereo) 15 | - [Middlebury v3](https://vision.middlebury.edu/stereo/submit3/) 16 | - [ETH3D](https://www.eth3d.net/datasets#low-res-two-view) 17 | 18 | #### Training Steps: 19 | ##### 1. Train A Basic Stereo Matching Network: 20 | ```bash 21 | python train_baseline.py --data_path (your SceneFlow data folder) 22 | ``` 23 | ##### 2. Graft VGG's Feature and Train the Feature Adaptor: 24 | ```bash 25 | python train_adaptor.py --data_path (your SceneFlow data folder) 26 | ``` 27 | ##### 3. 
Retrain the Cost Aggregation Module: 28 | ```bash 29 | python retrain_CostAggregation.py --data_path (your SceneFlow data folder) 30 | ``` 31 | 32 | #### Evaluation: 33 | ##### Evaluate on KITTI: 34 | ```bash 35 | python test_kitti.py --data_path (your KITTI training data folder) --load_path (the path of the final model) 36 | ``` 37 | ##### Evaluate on Middlebury-H: 38 | ```bash 39 | python test_middlebury.py --data_path (your MiddEval3 data folder) --load_path (the path of the final model) 40 | ``` 41 | ##### Evaluate on ETH3D: 42 | ```bash 43 | python test_eth3d.py --data_path (your ETH3D data folder) --load_path (the path of the final model) 44 | ``` 45 | 46 | #### Pretrained Models: 47 | [Google Drive](https://drive.google.com/drive/folders/1Ud9-HpHSXE5qMRQ17Fs8BNLfyE2VW03U?usp=sharing) 48 | -------------------------------------------------------------------------------- /dataloader/KITTI2012loader.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import torchvision.transforms as transforms 3 | import os 4 | from PIL import Image 5 | import random 6 | import numpy as np 7 | 8 | 9 | IMG_EXTENSIONS= [ 10 | '.jpg', '.JPG', '.jpeg', '.JPEG', 11 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP' 12 | ] 13 | 14 | 15 | def is_image_file(filename): 16 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 17 | 18 | 19 | def kt2012_loader(filepath): 20 | 21 | left_path = os.path.join(filepath, 'colored_0') 22 | right_path = os.path.join(filepath, 'colored_1') 23 | displ_path = os.path.join(filepath, 'disp_occ') 24 | 25 | total_name = [name for name in os.listdir(left_path) if name.find('_10') > -1] 26 | train_name = total_name[:160] 27 | val_name = total_name[160:] 28 | 29 | train_left = [] 30 | train_right = [] 31 | train_displ = [] 32 | for name in train_name: 33 | train_left.append(os.path.join(left_path, name)) 34 | train_right.append(os.path.join(right_path, name)) 35 | train_displ.append(os.path.join(displ_path, name)) 36 | 37 | val_left = [] 38 | val_right = [] 39 | val_displ = [] 40 | for name in val_name: 41 | val_left.append(os.path.join(left_path, name)) 42 | val_right.append(os.path.join(right_path, name)) 43 | val_displ.append(os.path.join(displ_path, name)) 44 | 45 | return train_left, train_right, train_displ, val_left, val_right, val_displ 46 | 47 | 48 | def img_loader(path): 49 | return Image.open(path).convert('RGB') 50 | 51 | 52 | def disparity_loader(path): 53 | return Image.open(path) 54 | 55 | 56 | transform = transforms.Compose([ 57 | transforms.ToTensor(), 58 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 59 | ]) 60 | 61 | 62 | class myDataset(data.Dataset): 63 | 64 | def __init__(self, left, right, left_disp, training, imgloader=img_loader, disploader=disparity_loader): 65 | self.left = left 66 | self.right = right 67 | self.left_disp = left_disp 68 | self.imgloader = imgloader 69 | self.disploader = disploader 70 | self.training = training 71 | 72 | def __getitem__(self, index): 73 | left = self.left[index] 74 | right = self.right[index] 75 | left_disp = self.left_disp[index] 76 | 77 | limg = self.imgloader(left) 78 | rimg = self.imgloader(right) 79 | ldisp = self.disploader(left_disp) 80 | 81 | if self.training: 82 | w, h = limg.size 83 | tw, th = 512, 256 84 | 85 | x1 = random.randint(0, w - tw) 86 | y1 = random.randint(0, h - th) 87 | 88 | limg = limg.crop((x1, y1, x1 + tw, y1 + th)) 89 | rimg = rimg.crop((x1, y1, x1 + tw,
y1 + th)) 90 | ldisp = np.ascontiguousarray(ldisp, dtype=np.float32)/256 91 | ldisp = ldisp[y1:y1 + th, x1:x1 + tw] 92 | 93 | limg = transform(limg) 94 | rimg = transform(rimg) 95 | 96 | return limg, rimg, ldisp 97 | 98 | else: 99 | w, h = limg.size 100 | 101 | limg = limg.crop((w-1232, h-368, w, h)) 102 | rimg = rimg.crop((w-1232, h-368, w, h)) 103 | ldisp = ldisp.crop((w-1232, h-368, w, h)) 104 | ldisp = np.ascontiguousarray(ldisp, dtype=np.float32)/256 105 | 106 | limg = transform(limg) 107 | rimg = transform(rimg) 108 | 109 | return limg, rimg, ldisp 110 | 111 | def __len__(self): 112 | return len(self.left) -------------------------------------------------------------------------------- /networks/feature_extraction.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.utils.data 4 | from torch.autograd import Variable 5 | import torch.nn.functional as F 6 | from torchvision import models 7 | import math 8 | import numpy as np 9 | import torchvision.transforms as transforms 10 | import PIL 11 | import os 12 | import matplotlib.pyplot as plt 13 | from networks.resnet import ResNet, Bottleneck, BasicBlock_Res 14 | from networks.vgg import vgg16 15 | from collections import OrderedDict 16 | 17 | 18 | class VGG_Feature(nn.Module): 19 | def __init__(self, fixed_param): 20 | super(VGG_Feature, self).__init__() 21 | 22 | self.fe = vgg16(pretrained=False) 23 | 24 | self.fe.load_state_dict( 25 | torch.load('networks/vgg16-397923af.pth')) 26 | 27 | features = self.fe.features 28 | 29 | self.to_feat = nn.Sequential() 30 | 31 | for i in range(15): 32 | self.to_feat.add_module(str(i), features[i]) 33 | 34 | if fixed_param: 35 | for p in self.to_feat.parameters(): 36 | p.requires_grad = False 37 | 38 | def forward(self, x): 39 | feature = self.to_feat(x) 40 | 41 | # feature = F.interpolate(feature, scale_factor=0.5, mode='bilinear', align_corners=True) 42 | 43 | return feature 44 | 45 | 46 | class VGG_Bn_Feature(nn.Module): 47 | def __init__(self): 48 | super(VGG_Bn_Feature, self).__init__() 49 | 50 | features = models.vgg16_bn(pretrained=True).cuda().eval().features 51 | self.to_feat = nn.Sequential() 52 | # for i in range(8): 53 | # self.to_feat.add_module(str(i), features[i]) 54 | 55 | for i in range(15): 56 | self.to_feat.add_module(str(i), features[i]) 57 | 58 | for p in self.to_feat.parameters(): 59 | p.requires_grad = False 60 | 61 | def forward(self, x): 62 | feature = self.to_feat(x) 63 | 64 | # feature = F.interpolate(feature, scale_factor=0.5, mode='bilinear', align_corners=True) 65 | 66 | return feature 67 | 68 | 69 | class Res18(nn.Module): 70 | def __init__(self): 71 | super(Res18, self).__init__() 72 | 73 | self.fe = ResNet(BasicBlock_Res, [2, 2, 2, 2]) 74 | 75 | # self.fe = ResNet(Bottleneck, [3, 4, 6, 3]) 76 | 77 | for p in self.fe.parameters(): 78 | p.requires_grad = False 79 | 80 | self.fe.load_state_dict( 81 | torch.load('networks/resnet18-5c106cde.pth')) 82 | 83 | def forward(self, x): 84 | 85 | self.fe.eval() 86 | 87 | with torch.no_grad(): 88 | feature = self.fe(x) 89 | 90 | return feature 91 | 92 | 93 | class Res50(nn.Module): 94 | def __init__(self): 95 | super(Res50, self).__init__() 96 | 97 | self.fe = ResNet(Bottleneck, [3, 4, 6, 3]) 98 | 99 | for p in self.fe.parameters(): 100 | p.requires_grad = False 101 | 102 | # self.fe.load_state_dict( 103 | # torch.load('networks/resnet50-19c8e357.pth')) 104 | self.fe.load_state_dict( 105 | torch.load('networks/DenseCL_R50_imagenet.pth')) 106 
| 107 | def forward(self, x): 108 | 109 | self.fe.eval() 110 | 111 | with torch.no_grad(): 112 | feature = self.fe(x) 113 | 114 | return feature 115 | 116 | 117 | if __name__ == '__main__': 118 | os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID" 119 | os.environ['CUDA_VISIBLE_DEVICES'] = "2" 120 | from collections import OrderedDict 121 | ckpt = torch.load('selfTrainVGG_withDA.pth') 122 | new_dict = OrderedDict() 123 | for k, v in ckpt.items(): 124 | new_k = k.replace('module.', '') 125 | new_dict[new_k] = v 126 | 127 | torch.save(new_dict, 'selfTrainVGG_withDA.pth') -------------------------------------------------------------------------------- /dataloader/ETH3D_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | from dataloader import readpfm as rp 4 | import dataloader.preprocess 5 | import torch.utils.data as data 6 | import torchvision.transforms as transforms 7 | import numpy as np 8 | import random 9 | 10 | IMG_EXTENSIONS= [ 11 | '.jpg', '.JPG', '.jpeg', '.JPEG', 12 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP' 13 | ] 14 | 15 | 16 | def is_image_file(filename): 17 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 18 | 19 | 20 | # filepath = '/media/data/dataset/ETH3D/' 21 | def et_loader(filepath): 22 | 23 | left_img = [] 24 | right_img = [] 25 | disp_gt = [] 26 | occ_mask = [] 27 | 28 | img_path = os.path.join(filepath, 'two_view_training') 29 | gt_path = os.path.join(filepath, 'two_view_training_gt') 30 | 31 | for c in os.listdir(img_path): 32 | img_cpath = os.path.join(img_path, c) 33 | gt_cpath = os.path.join(gt_path, c) 34 | 35 | left_img.append(os.path.join(img_cpath, 'im0.png')) 36 | right_img.append(os.path.join(img_cpath, 'im1.png')) 37 | disp_gt.append(os.path.join(gt_cpath, 'disp0GT.pfm')) 38 | occ_mask.append(os.path.join(gt_cpath, 'mask0nocc.png')) 39 | 40 | return left_img, right_img, disp_gt, occ_mask, 41 | 42 | 43 | def img_loader(path): 44 | return Image.open(path).convert('RGB') 45 | 46 | 47 | def disparity_loader(path): 48 | return rp.readPFM(path) 49 | 50 | 51 | class myDataset(data.Dataset): 52 | 53 | def __init__(self, left, right, disp_gt, occ_mask, training, imgloader=img_loader, dploader = disparity_loader): 54 | self.left = left 55 | self.right = right 56 | self.disp_gt = disp_gt 57 | self.occ_mask = occ_mask 58 | self.imgloader = imgloader 59 | self.dploader = dploader 60 | self.training = training 61 | self.img_transorm = transforms.Compose([ 62 | transforms.ToTensor(), 63 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 64 | 65 | def __getitem__(self, index): 66 | left = self.left[index] 67 | right = self.right[index] 68 | disp_L = self.disp_L[index] 69 | disp_R = self.disp_R[index] 70 | 71 | left_img = self.imgloader(left) 72 | right_img = self.imgloader(right) 73 | dataL, _ = self.dploader(disp_L) 74 | dataL = np.ascontiguousarray(dataL, dtype=np.float32) 75 | dataR, _ = self.dploader(disp_R) 76 | dataR = np.ascontiguousarray(dataR, dtype=np.float32) 77 | 78 | if self.training: 79 | w, h = left_img.size 80 | tw, th = 512, 256 81 | x1 = random.randint(0, w - tw) 82 | y1 = random.randint(0, h - th) 83 | 84 | left_img = left_img.crop((x1, y1, x1+tw, y1+th)) 85 | right_img = right_img.crop((x1, y1, x1+tw, y1+th)) 86 | dataL = dataL[y1:y1+th, x1:x1+tw] 87 | dataR = dataR[y1:y1+th, x1:x1+tw] 88 | 89 | left_img = self.img_transorm(left_img) 90 | right_img = self.img_transorm(right_img) 91 | 92 | return left_img, right_img, dataL, 
dataR 93 | 94 | else: 95 | w, h = left_img.size 96 | left_img = left_img.crop((w-960, h-544, w, h)) 97 | right_img = right_img.crop((w-960, h-544, w, h)) 98 | 99 | left_img = self.img_transorm(left_img) 100 | right_img = self.img_transorm(right_img) 101 | 102 | dataL = Image.fromarray(dataL).crop((w-960, h-544, w, h)) 103 | dataL = np.ascontiguousarray(dataL) 104 | dataR = Image.fromarray(dataR).crop((w-960, h-544, w, h)) 105 | dataR = np.ascontiguousarray(dataR) 106 | 107 | return left_img, right_img, dataL, dataR 108 | 109 | def __len__(self): 110 | return len(self.left) 111 | 112 | 113 | 114 | 115 | 116 | -------------------------------------------------------------------------------- /test_eth3d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | from torch.autograd import grad as Grad 6 | from torchvision import transforms 7 | import os 8 | import copy 9 | import skimage.io 10 | from collections import OrderedDict 11 | from tqdm import tqdm, trange 12 | from PIL import Image 13 | import numpy as np 14 | import matplotlib.pyplot as plt 15 | import argparse 16 | 17 | from dataloader import ETH3D_loader as et 18 | from dataloader.readpfm import readPFM 19 | import networks.Aggregator as Agg 20 | import networks.feature_extraction as FE 21 | import networks.U_net as un 22 | 23 | 24 | parser = argparse.ArgumentParser(description='GraftNet') 25 | parser.add_argument('--no_cuda', action='store_true', default=False) 26 | parser.add_argument('--gpu_id', type=str, default='2') 27 | parser.add_argument('--seed', type=str, default=0) 28 | parser.add_argument('--data_path', type=str, default='/media/data/dataset/ETH3D/') 29 | parser.add_argument('--load_path', type=str, default='trained_models/checkpoint_final_10epoch.tar') 30 | parser.add_argument('--max_disp', type=int, default=192) 31 | args = parser.parse_args() 32 | 33 | 34 | if not args.no_cuda: 35 | os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID" 36 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id 37 | cuda = torch.cuda.is_available() 38 | 39 | 40 | all_limg, all_rimg, all_disp, all_mask = et.et_loader(args.data_path) 41 | 42 | 43 | fe_model = FE.VGG_Feature(fixed_param=True).eval() 44 | adaptor = un.U_Net_v4(img_ch=256, output_ch=64).eval() 45 | agg_model = Agg.PSMAggregator(args.max_disp, udc=True).eval() 46 | 47 | if cuda: 48 | fe_model = nn.DataParallel(fe_model.cuda()) 49 | adaptor = nn.DataParallel(adaptor.cuda()) 50 | agg_model = nn.DataParallel(agg_model.cuda()) 51 | 52 | adaptor.load_state_dict(torch.load(args.load_path)['fa_net']) 53 | agg_model.load_state_dict(torch.load(args.load_path)['net']) 54 | 55 | 56 | pred_mae = 0 57 | pred_op = 0 58 | for i in trange(len(all_limg)): 59 | limg = Image.open(all_limg[i]).convert('RGB') 60 | rimg = Image.open(all_rimg[i]).convert('RGB') 61 | 62 | w, h = limg.size 63 | wi, hi = (w // 16 + 1) * 16, (h // 16 + 1) * 16 64 | limg = limg.crop((w - wi, h - hi, w, h)) 65 | rimg = rimg.crop((w - wi, h - hi, w, h)) 66 | 67 | limg_tensor = transforms.Compose([ 68 | transforms.ToTensor(), 69 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])(limg) 70 | rimg_tensor = transforms.Compose([ 71 | transforms.ToTensor(), 72 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])(rimg) 73 | limg_tensor = limg_tensor.unsqueeze(0).cuda() 74 | rimg_tensor = rimg_tensor.unsqueeze(0).cuda() 75 | 76 | disp_gt, _ = 
readPFM(all_disp[i]) 77 | disp_gt = np.ascontiguousarray(disp_gt, dtype=np.float32) 78 | disp_gt[disp_gt == np.inf] = 0 79 | gt_tensor = torch.FloatTensor(disp_gt).unsqueeze(0).unsqueeze(0).cuda() 80 | 81 | occ_mask = np.ascontiguousarray(Image.open(all_mask[i])) 82 | 83 | with torch.no_grad(): 84 | left_fea = fe_model(limg_tensor) 85 | right_fea = fe_model(rimg_tensor) 86 | 87 | left_fea = adaptor(left_fea) 88 | right_fea = adaptor(right_fea) 89 | 90 | pred_disp = agg_model(left_fea, right_fea, gt_tensor, training=False) 91 | pred_disp = pred_disp[:, hi - h:, wi - w:] 92 | 93 | predict_np = pred_disp.squeeze().cpu().numpy() 94 | 95 | op_thresh = 1 96 | mask = (disp_gt > 0) & (occ_mask == 255) 97 | # mask = disp_gt > 0 98 | error = np.abs(predict_np * mask.astype(np.float32) - disp_gt * mask.astype(np.float32)) 99 | 100 | pred_error = np.abs(predict_np * mask.astype(np.float32) - disp_gt * mask.astype(np.float32)) 101 | pred_op += np.sum(pred_error > op_thresh) / np.sum(mask) 102 | pred_mae += np.mean(pred_error[mask]) 103 | 104 | print(pred_mae / len(all_limg)) 105 | print(pred_op / len(all_limg)) -------------------------------------------------------------------------------- /networks/stackhourglass.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import torch.nn as nn 4 | import torch.utils.data 5 | from torch.autograd import Variable 6 | import torch.nn.functional as F 7 | import math 8 | from networks.submodule import convbn, convbn_3d, DisparityRegression 9 | 10 | 11 | class hourglass(nn.Module): 12 | def __init__(self, inplanes): 13 | super(hourglass, self).__init__() 14 | 15 | self.conv1 = nn.Sequential(convbn_3d(inplanes, inplanes*2, kernel_size=3, stride=2, pad=1), 16 | nn.ReLU(inplace=True)) 17 | 18 | self.conv2 = convbn_3d(inplanes*2, inplanes*2, kernel_size=3, stride=1, pad=1) 19 | 20 | self.conv3 = nn.Sequential(convbn_3d(inplanes*2, inplanes*2, kernel_size=3, stride=2, pad=1), 21 | nn.ReLU(inplace=True)) 22 | 23 | self.conv4 = nn.Sequential(convbn_3d(inplanes*2, inplanes*2, kernel_size=3, stride=1, pad=1), 24 | nn.ReLU(inplace=True)) 25 | 26 | self.conv5 = nn.Sequential(nn.ConvTranspose3d(inplanes*2, inplanes*2, kernel_size=3, padding=1, output_padding=1, stride=2,bias=False), 27 | nn.BatchNorm3d(inplanes*2)) #+conv2 28 | 29 | self.conv6 = nn.Sequential(nn.ConvTranspose3d(inplanes*2, inplanes, kernel_size=3, padding=1, output_padding=1, stride=2,bias=False), 30 | nn.BatchNorm3d(inplanes)) #+x 31 | 32 | def forward(self, x ,presqu, postsqu): 33 | 34 | out = self.conv1(x) #in:1/4 out:1/8 35 | pre = self.conv2(out) #in:1/8 out:1/8 36 | 37 | if postsqu is not None: 38 | pre = F.relu(pre + postsqu, inplace=True) 39 | else: 40 | pre = F.relu(pre, inplace=True) 41 | 42 | # print('pre2', pre.size()) 43 | 44 | out = self.conv3(pre) #in:1/8 out:1/16 45 | out = self.conv4(out) #in:1/16 out:1/16 46 | 47 | # print('out', out.size()) 48 | 49 | if presqu is not None: 50 | post = F.relu(self.conv5(out)+presqu, inplace=True) #in:1/16 out:1/8 51 | else: 52 | post = F.relu(self.conv5(out)+pre, inplace=True) 53 | 54 | out = self.conv6(post) #in:1/8 out:1/4 55 | 56 | return out, pre, post 57 | 58 | 59 | class hourglass_gwcnet(nn.Module): 60 | def __init__(self, inplanes): 61 | super(hourglass_gwcnet, self).__init__() 62 | 63 | self.conv1 = nn.Sequential(convbn_3d(inplanes, inplanes * 2, kernel_size=3, stride=2, pad=1), 64 | nn.ReLU(inplace=True)) 65 | self.conv2 = nn.Sequential(convbn_3d(inplanes 
* 2, inplanes * 2, kernel_size=3, stride=1, pad=1), 66 | nn.ReLU(inplace=True)) 67 | self.conv3 = nn.Sequential(convbn_3d(inplanes * 2, inplanes * 4, kernel_size=3, stride=2, pad=1), 68 | nn.ReLU(inplace=True)) 69 | self.conv4 = nn.Sequential(convbn_3d(inplanes * 4, inplanes * 4, 3, 1, 1), 70 | nn.ReLU(inplace=True)) 71 | self.conv5 = nn.Sequential(nn.ConvTranspose3d(inplanes * 4, inplanes * 2, kernel_size=3, padding=1, 72 | output_padding=1, stride=2, bias=False), 73 | nn.BatchNorm3d(inplanes * 2)) 74 | self.conv6 = nn.Sequential(nn.ConvTranspose3d(inplanes * 2, inplanes, kernel_size=3, padding=1, 75 | output_padding=1, stride=2, bias=False), 76 | nn.BatchNorm3d(inplanes)) 77 | 78 | self.redir1 = convbn_3d(inplanes, inplanes, kernel_size=1, stride=1, pad=0) 79 | self.redir2 = convbn_3d(inplanes * 2, inplanes * 2, kernel_size=1, stride=1, pad=0) 80 | 81 | def forward(self, x): 82 | 83 | conv1 = self.conv1(x) 84 | conv2 = self.conv2(conv1) 85 | 86 | conv3 = self.conv3(conv2) 87 | conv4 = self.conv4(conv3) 88 | 89 | conv5 = F.relu(self.conv5(conv4) + self.redir2(conv2), inplace=True) 90 | conv6 = F.relu(self.conv6(conv5) + self.redir1(x), inplace=True) 91 | 92 | return conv6 93 | 94 | -------------------------------------------------------------------------------- /test_kitti.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | from torch.autograd import grad as Grad 6 | from torchvision import transforms 7 | import skimage.io 8 | import os 9 | import copy 10 | from collections import OrderedDict 11 | from tqdm import tqdm, trange 12 | from PIL import Image 13 | import numpy as np 14 | import matplotlib.pyplot as plt 15 | import argparse 16 | 17 | from dataloader import KITTIloader as kt 18 | from dataloader import KITTI2012loader as kt2012 19 | import networks.Aggregator as Agg 20 | import networks.feature_extraction as FE 21 | import networks.U_net as un 22 | 23 | 24 | parser = argparse.ArgumentParser(description='GraftNet') 25 | parser.add_argument('--no_cuda', action='store_true', default=False) 26 | parser.add_argument('--gpu_id', type=str, default='2') 27 | parser.add_argument('--seed', type=str, default=0) 28 | parser.add_argument('--kitti', type=str, default='2015') 29 | parser.add_argument('--data_path', type=str, default='/media/data/dataset/KITTI/data_scene_flow/training/') 30 | parser.add_argument('--load_path', type=str, default='trained_models/checkpoint_final_10epoch.tar') 31 | parser.add_argument('--max_disp', type=int, default=192) 32 | args = parser.parse_args() 33 | 34 | if not args.no_cuda: 35 | os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID" 36 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id 37 | cuda = torch.cuda.is_available() 38 | 39 | 40 | if args.kitti == '2015': 41 | all_limg, all_rimg, all_ldisp, test_limg, test_rimg, test_ldisp = kt.kt_loader(args.data_path) 42 | else: 43 | all_limg, all_rimg, all_ldisp, test_limg, test_rimg, test_ldisp = kt2012.kt2012_loader(args.data_path) 44 | 45 | test_limg = all_limg + test_limg 46 | test_rimg = all_rimg + test_rimg 47 | test_ldisp = all_ldisp + test_ldisp 48 | 49 | fe_model = FE.VGG_Feature(fixed_param=True).eval() 50 | adaptor = un.U_Net_v4(img_ch=256, output_ch=64).eval() 51 | agg_model = Agg.PSMAggregator(args.max_disp, udc=True).eval() 52 | 53 | if cuda: 54 | fe_model = nn.DataParallel(fe_model.cuda()) 55 | adaptor = nn.DataParallel(adaptor.cuda()) 56 | agg_model = 
nn.DataParallel(agg_model.cuda()) 57 | 58 | adaptor.load_state_dict(torch.load(args.load_path)['fa_net']) 59 | agg_model.load_state_dict(torch.load(args.load_path)['net']) 60 | 61 | pred_mae = 0 62 | pred_op = 0 63 | for i in trange(len(test_limg)): 64 | limg = Image.open(test_limg[i]).convert('RGB') 65 | rimg = Image.open(test_rimg[i]).convert('RGB') 66 | 67 | w, h = limg.size 68 | m = 16 69 | wi, hi = (w // m + 1) * m, (h // m + 1) * m 70 | limg = limg.crop((w - wi, h - hi, w, h)) 71 | rimg = rimg.crop((w - wi, h - hi, w, h)) 72 | 73 | transform = transforms.Compose([ 74 | transforms.ToTensor(), 75 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 76 | 77 | limg_tensor = transform(limg) 78 | rimg_tensor = transform(rimg) 79 | limg_tensor = limg_tensor.unsqueeze(0).cuda() 80 | rimg_tensor = rimg_tensor.unsqueeze(0).cuda() 81 | 82 | disp_gt = Image.open(test_ldisp[i]) 83 | disp_gt = np.ascontiguousarray(disp_gt, dtype=np.float32) / 256 84 | gt_tensor = torch.FloatTensor(disp_gt).unsqueeze(0).unsqueeze(0).cuda() 85 | 86 | with torch.no_grad(): 87 | left_fea = fe_model(limg_tensor) 88 | right_fea = fe_model(rimg_tensor) 89 | 90 | left_fea = adaptor(left_fea) 91 | right_fea = adaptor(right_fea) 92 | 93 | pred_disp = agg_model(left_fea, right_fea, gt_tensor, training=False) 94 | pred_disp = pred_disp[:, hi - h:, wi - w:] 95 | 96 | predict_np = pred_disp.squeeze().cpu().numpy() 97 | 98 | op_thresh = 3 99 | mask = (disp_gt > 0) & (disp_gt < args.max_disp) 100 | error = np.abs(predict_np * mask.astype(np.float32) - disp_gt * mask.astype(np.float32)) 101 | 102 | pred_error = np.abs(predict_np * mask.astype(np.float32) - disp_gt * mask.astype(np.float32)) 103 | pred_op += np.sum((pred_error > op_thresh)) / np.sum(mask) 104 | pred_mae += np.mean(pred_error[mask]) 105 | 106 | print(pred_mae / len(test_limg)) 107 | print(pred_op / len(test_limg)) -------------------------------------------------------------------------------- /dataloader/vKITTI_loader.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import torchvision.transforms as transforms 3 | import os 4 | from PIL import Image 5 | import random 6 | import numpy as np 7 | 8 | 9 | def vkt_loader(filepath): 10 | all_limg = [] 11 | all_rimg = [] 12 | all_disp = [] 13 | 14 | img_path = os.path.join(filepath, 'vkitti_2.0.3_rgb') 15 | depth_path = os.path.join(filepath, 'vkitti_2.0.3_depth') 16 | 17 | for scene in os.listdir(img_path): 18 | img_scenes_path = os.path.join(img_path, scene, 'clone/frames/rgb') 19 | depth_scenes_path = os.path.join(depth_path, scene, 'clone/frames/depth') 20 | 21 | for name in os.listdir(os.path.join(img_scenes_path, 'Camera_0')): 22 | all_limg.append(os.path.join(img_scenes_path, 'Camera_0', name)) 23 | all_rimg.append(os.path.join(img_scenes_path, 'Camera_1', name)) 24 | all_disp.append(os.path.join(depth_scenes_path, 'Camera_0', 25 | name.replace('jpg', 'png').replace('rgb', 'depth'))) 26 | 27 | total_num = len(all_limg) 28 | train_length = int(total_num * 0.75) 29 | 30 | train_limg = all_limg[:train_length] 31 | train_rimg = all_rimg[:train_length] 32 | train_disp = all_disp[:train_length] 33 | 34 | val_limg = all_limg[train_length:] 35 | val_rimg = all_rimg[train_length:] 36 | val_disp = all_disp[train_length:] 37 | 38 | return train_limg, train_rimg, train_disp, val_limg, val_rimg, val_disp 39 | 40 | 41 | def img_loader(path): 42 | return Image.open(path).convert('RGB') 43 | 44 | 45 | def disparity_loader(path): 46 | 
return Image.open(path) 47 | 48 | 49 | class vkDataset(data.Dataset): 50 | 51 | def __init__(self, left, right, left_disp, training, imgloader=img_loader, disploader=disparity_loader): 52 | self.left = left 53 | self.right = right 54 | self.left_disp = left_disp 55 | self.imgloader = imgloader 56 | self.disploader = disploader 57 | self.training = training 58 | self.transform = transforms.Compose([ 59 | transforms.ToTensor(), 60 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 61 | ]) 62 | 63 | def __getitem__(self, index): 64 | left = self.left[index] 65 | right = self.right[index] 66 | left_disp = self.left_disp[index] 67 | 68 | limg = self.imgloader(left) 69 | rimg = self.imgloader(right) 70 | ldisp = self.disploader(left_disp) 71 | 72 | if self.training: 73 | w, h = limg.size 74 | tw, th = 512, 256 75 | 76 | x1 = random.randint(0, w - tw) 77 | y1 = random.randint(0, h - th) 78 | 79 | limg = limg.crop((x1, y1, x1 + tw, y1 + th)) 80 | rimg = rimg.crop((x1, y1, x1 + tw, y1 + th)) 81 | 82 | limg = self.transform(limg) 83 | rimg = self.transform(rimg) 84 | 85 | baseline, fx, fy = 0.532725, 725.0087, 725.0087 86 | camera_params = {'baseline': baseline, 87 | 'fx': fx, 88 | 'fy': fy} 89 | 90 | ldepth = np.ascontiguousarray(ldisp, dtype=np.float32) / 100. 91 | ldisp = baseline * fy / ldepth 92 | ldisp = ldisp[y1:y1 + th, x1:x1 + tw] 93 | 94 | return limg, rimg, ldisp, ldisp 95 | 96 | else: 97 | w, h = limg.size 98 | 99 | limg = limg.crop((w-1232, h-368, w, h)) 100 | rimg = rimg.crop((w-1232, h-368, w, h)) 101 | ldisp = ldisp.crop((w-1232, h-368, w, h)) 102 | 103 | limg = self.transform(limg) 104 | rimg = self.transform(rimg) 105 | 106 | baseline, fx, fy = 0.532725, 725.0087, 725.0087 107 | ldepth = np.ascontiguousarray(ldisp, dtype=np.float32) / 100. 
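# Note (added comment): vKITTI 2 depth PNGs encode depth in centimetres, so the division by 100 above converts it to metres; the next line then converts depth to disparity via disparity = baseline * focal_length / depth (fx == fy == 725.0087 px in this setup, so using fy here is equivalent to using fx).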
108 | ldisp = baseline * fy / ldepth 109 | 110 | return limg, rimg, ldisp, ldisp 111 | 112 | def __len__(self): 113 | return len(self.left) 114 | 115 | 116 | if __name__ == '__main__': 117 | 118 | path = '/media/data2/Dataset/vKITTI2/' 119 | a, b, c, d, e, f = vkt_loader(path) 120 | print(len(a)) -------------------------------------------------------------------------------- /test_middlebury.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchvision import transforms 5 | from torch.autograd import Variable 6 | from torch.autograd import grad as Grad 7 | import skimage.io 8 | import os 9 | import copy 10 | from collections import OrderedDict 11 | from tqdm import tqdm, trange 12 | from PIL import Image 13 | import numpy as np 14 | import matplotlib.pyplot as plt 15 | import cv2 16 | import argparse 17 | 18 | from dataloader import middlebury_loader as mb 19 | from dataloader import readpfm as rp 20 | import networks.Aggregator as Agg 21 | import networks.U_net as un 22 | import networks.feature_extraction as FE 23 | 24 | 25 | parser = argparse.ArgumentParser(description='GraftNet') 26 | parser.add_argument('--no_cuda', action='store_true', default=False) 27 | parser.add_argument('--gpu_id', type=str, default='2') 28 | parser.add_argument('--seed', type=str, default=0) 29 | parser.add_argument('--resolution', type=str, default='H') 30 | parser.add_argument('--data_path', type=str, default='/media/data/dataset/MiddEval3-data-H/') 31 | parser.add_argument('--load_path', type=str, default='trained_models/checkpoint_final_10epoch.tar') 32 | parser.add_argument('--max_disp', type=int, default=192) 33 | args = parser.parse_args() 34 | 35 | if not args.no_cuda: 36 | os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID" 37 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id 38 | cuda = torch.cuda.is_available() 39 | 40 | train_limg, train_rimg, train_gt, test_limg, test_rimg = mb.mb_loader(args.data_path, res=args.resolution) 41 | 42 | fe_model = FE.VGG_Feature(fixed_param=True).eval() 43 | adaptor = un.U_Net_v4(img_ch=256, output_ch=64).eval() 44 | agg_model = Agg.PSMAggregator(args.max_disp, udc=True).eval() 45 | 46 | if cuda: 47 | fe_model = nn.DataParallel(fe_model.cuda()) 48 | adaptor = nn.DataParallel(adaptor.cuda()) 49 | agg_model = nn.DataParallel(agg_model.cuda()) 50 | 51 | adaptor.load_state_dict(torch.load(args.load_path)['fa_net']) 52 | agg_model.load_state_dict(torch.load(args.load_path)['net']) 53 | 54 | 55 | def test_trainset(): 56 | op = 0 57 | mae = 0 58 | 59 | for i in trange(len(train_limg)): 60 | 61 | limg_path = train_limg[i] 62 | rimg_path = train_rimg[i] 63 | 64 | limg = Image.open(limg_path).convert('RGB') 65 | rimg = Image.open(rimg_path).convert('RGB') 66 | 67 | w, h = limg.size 68 | wi, hi = (w // 16 + 1) * 16, (h // 16 + 1) * 16 69 | 70 | limg = limg.crop((w - wi, h - hi, w, h)) 71 | rimg = rimg.crop((w - wi, h - hi, w, h)) 72 | 73 | limg_tensor = transforms.Compose([ 74 | transforms.ToTensor(), 75 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])(limg) 76 | rimg_tensor = transforms.Compose([ 77 | transforms.ToTensor(), 78 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])(rimg) 79 | limg_tensor = limg_tensor.unsqueeze(0).cuda() 80 | rimg_tensor = rimg_tensor.unsqueeze(0).cuda() 81 | 82 | with torch.no_grad(): 83 | left_fea = fe_model(limg_tensor) 84 | right_fea = fe_model(rimg_tensor) 85 | 86 | left_fea = 
adaptor(left_fea) 87 | right_fea = adaptor(right_fea) 88 | 89 | pred_disp = agg_model(left_fea, right_fea, limg_tensor, training=False) 90 | pred_disp = pred_disp[:, hi - h:, wi - w:] 91 | 92 | pred_np = pred_disp.squeeze().cpu().numpy() 93 | 94 | torch.cuda.empty_cache() 95 | 96 | disp_gt, _ = rp.readPFM(train_gt[i]) 97 | disp_gt = np.ascontiguousarray(disp_gt, dtype=np.float32) 98 | disp_gt[disp_gt == np.inf] = 0 99 | 100 | occ_mask = Image.open(train_gt[i].replace('disp0GT.pfm', 'mask0nocc.png')).convert('L') 101 | occ_mask = np.ascontiguousarray(occ_mask, dtype=np.float32) 102 | 103 | mask = (disp_gt <= 0) | (occ_mask != 255) | (disp_gt >= args.max_disp) 104 | # mask = (disp_gt <= 0) | (disp_gt >= maxdisp) 105 | 106 | error = np.abs(pred_np - disp_gt) 107 | error[mask] = 0 108 | 109 | if i in [6, 8, 9, 12, 14]: 110 | k = 1 111 | else: 112 | k = 1 113 | 114 | op += np.sum(error > 2.0) / (w * h - np.sum(mask)) * k 115 | mae += np.sum(error) / (w * h - np.sum(mask)) * k 116 | 117 | print(op / 15 * 100) 118 | print(mae / 15) 119 | 120 | 121 | if __name__ == '__main__': 122 | test_trainset() 123 | # test_testset() -------------------------------------------------------------------------------- /dataloader/middlebury_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | from dataloader import readpfm as rp 4 | import torch.utils.data as data 5 | import torchvision.transforms as transforms 6 | import numpy as np 7 | import random 8 | 9 | 10 | def mb_loader(filepath, res): 11 | 12 | train_path = os.path.join(filepath, 'training' + res) 13 | test_path = os.path.join(filepath, 'test' + res) 14 | gt_path = train_path.replace('training' + res, 'Eval3_GT/training' + res) 15 | 16 | train_left = [] 17 | train_right = [] 18 | train_gt = [] 19 | 20 | for c in os.listdir(train_path): 21 | train_left.append(os.path.join(train_path, c, 'im0.png')) 22 | train_right.append(os.path.join(train_path, c, 'im1.png')) 23 | train_gt.append(os.path.join(gt_path, c, 'disp0GT.pfm')) 24 | 25 | test_left = [] 26 | test_right = [] 27 | for c in os.listdir(test_path): 28 | test_left.append(os.path.join(test_path, c, 'im0.png')) 29 | test_right.append(os.path.join(test_path, c, 'im1.png')) 30 | 31 | train_left = sorted(train_left) 32 | train_right = sorted(train_right) 33 | train_gt = sorted(train_gt) 34 | test_left = sorted(test_left) 35 | test_right = sorted(test_right) 36 | 37 | return train_left, train_right, train_gt, test_left, test_right 38 | 39 | 40 | def img_loader(path): 41 | return Image.open(path).convert('RGB') 42 | 43 | 44 | def disparity_loader(path): 45 | return rp.readPFM(path) 46 | 47 | 48 | class myDataset(data.Dataset): 49 | 50 | def __init__(self, left, right, left_disp, training, imgloader=img_loader, dploader = disparity_loader): 51 | self.left = left 52 | self.right = right 53 | self.disp_L = left_disp 54 | self.imgloader = imgloader 55 | self.dploader = dploader 56 | self.training = training 57 | self.img_transorm = transforms.Compose([ 58 | transforms.ToTensor(), 59 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 60 | 61 | def __getitem__(self, index): 62 | left = self.left[index] 63 | right = self.right[index] 64 | disp_L = self.disp_L[index] 65 | 66 | left_img = self.imgloader(left) 67 | right_img = self.imgloader(right) 68 | dataL, scaleL = self.dploader(disp_L) 69 | dataL = Image.fromarray(np.ascontiguousarray(dataL, dtype=np.float32)) 70 | 71 | if self.training: 72 | w, h = left_img.size 
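# Note (added comment): in the random-resize augmentation below, the ground-truth disparity map is multiplied by the same scale factor s as the images, because disparity is measured in pixels and scales linearly with image width.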
73 | 74 | # random resize 75 | s = np.random.uniform(0.95, 1.05, 1) 76 | rw, rh = np.round(w*s), np.round(h*s) 77 | left_img = left_img.resize((rw, rh), Image.NEAREST) 78 | right_img = right_img.resize((rw, rh), Image.NEAREST) 79 | dataL = dataL.resize((rw, rh), Image.NEAREST) 80 | dataL = Image.fromarray(np.array(dataL) * s) 81 | 82 | # random horizontal flip 83 | p = np.random.rand(1) 84 | if p >= 0.5: 85 | left_img = horizontal_flip(left_img) 86 | right_img = horizontal_flip(right_img) 87 | dataL = horizontal_flip(dataL) 88 | 89 | w, h = left_img.size 90 | tw, th = 320, 240 91 | x1 = random.randint(0, w - tw) 92 | y1 = random.randint(0, h - th) 93 | 94 | left_img = left_img.crop((x1, y1, x1+tw, y1+th)) 95 | right_img = right_img.crop((x1, y1, x1+tw, y1+th)) 96 | dataL = dataL.crop((x1, y1, x1+tw, y1+th)) 97 | 98 | left_img = self.img_transorm(left_img) 99 | right_img = self.img_transorm(right_img) 100 | 101 | dataL = np.array(dataL) 102 | return left_img, right_img, dataL 103 | 104 | else: 105 | w, h = left_img.size 106 | left_img = left_img.resize((w // 32 * 32, h // 32 * 32)) 107 | right_img = right_img.resize((w // 32 * 32, h // 32 * 32)) 108 | 109 | left_img = self.img_transorm(left_img) 110 | right_img = self.img_transorm(right_img) 111 | 112 | dataL = np.array(dataL) 113 | return left_img, right_img, dataL 114 | 115 | def __len__(self): 116 | return len(self.left) 117 | 118 | 119 | def horizontal_flip(img): 120 | img_np = np.array(img) 121 | img_np = np.flip(img_np, axis=1) 122 | img = Image.fromarray(img_np) 123 | return img 124 | 125 | 126 | if __name__ == '__main__': 127 | train_left, train_right, train_gt, _, _ = mb_loader('/media/data/dataset/MiddEval3-data-Q/', res='Q') 128 | H, W = 0, 0 129 | for l in train_right: 130 | left_img = Image.open(l).convert('RGB') 131 | h, w = left_img.size 132 | H += h 133 | W += w 134 | print(H / 15, W / 15) -------------------------------------------------------------------------------- /dataloader/KITTIloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as data 3 | import torchvision.transforms as transforms 4 | import os 5 | from PIL import Image 6 | import random 7 | import numpy as np 8 | 9 | 10 | IMG_EXTENSIONS= [ 11 | '.jpg', '.JPG', '.jpeg', '.JPEG', 12 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP' 13 | ] 14 | 15 | 16 | def is_image_file(filename): 17 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 18 | 19 | 20 | def kt_loader(filepath): 21 | 22 | left_path = os.path.join(filepath, 'image_2') 23 | right_path = os.path.join(filepath, 'image_3') 24 | displ_path = os.path.join(filepath, 'disp_occ_0') 25 | 26 | # total_name = sorted([name for name in os.listdir(left_path) if name.find('_10') > -1]) 27 | total_name = [name for name in os.listdir(left_path) if name.find('_10') > -1] 28 | train_name = total_name[:160] 29 | val_name = total_name[160:] 30 | 31 | train_left = [] 32 | train_right = [] 33 | train_displ = [] 34 | for name in train_name: 35 | train_left.append(os.path.join(left_path, name)) 36 | train_right.append(os.path.join(right_path, name)) 37 | train_displ.append(os.path.join(displ_path, name)) 38 | 39 | val_left = [] 40 | val_right = [] 41 | val_displ = [] 42 | for name in val_name: 43 | val_left.append(os.path.join(left_path, name)) 44 | val_right.append(os.path.join(right_path, name)) 45 | val_displ.append(os.path.join(displ_path, name)) 46 | 47 | return train_left, train_right, train_displ, val_left, 
val_right, val_displ 48 | 49 | 50 | def kt2012_loader(filepath): 51 | 52 | left_path = os.path.join(filepath, 'colored_0') 53 | right_path = os.path.join(filepath, 'colored_1') 54 | displ_path = os.path.join(filepath, 'disp_occ') 55 | 56 | total_name = sorted([name for name in os.listdir(left_path) if name.find('_10') > -1]) 57 | train_name = total_name[:160] 58 | val_name = total_name[160:] 59 | 60 | train_left = [] 61 | train_right = [] 62 | train_displ = [] 63 | for name in train_name: 64 | train_left.append(os.path.join(left_path, name)) 65 | train_right.append(os.path.join(right_path, name)) 66 | train_displ.append(os.path.join(displ_path, name)) 67 | 68 | val_left = [] 69 | val_right = [] 70 | val_displ = [] 71 | for name in val_name: 72 | val_left.append(os.path.join(left_path, name)) 73 | val_right.append(os.path.join(right_path, name)) 74 | val_displ.append(os.path.join(displ_path, name)) 75 | 76 | return train_left, train_right, train_displ, val_left, val_right, val_displ 77 | 78 | 79 | def img_loader(path): 80 | return Image.open(path).convert('RGB') 81 | 82 | 83 | def disparity_loader(path): 84 | return Image.open(path) 85 | 86 | 87 | class myDataset(data.Dataset): 88 | 89 | def __init__(self, left, right, left_disp, training, imgloader=img_loader, disploader=disparity_loader): 90 | self.left = left 91 | self.right = right 92 | self.left_disp = left_disp 93 | 94 | self.training = training 95 | self.imgloader = imgloader 96 | self.disploader = disploader 97 | 98 | self.transform = transforms.Compose([ 99 | transforms.ToTensor(), 100 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 101 | 102 | def __getitem__(self, index): 103 | left = self.left[index] 104 | right = self.right[index] 105 | left_disp = self.left_disp[index] 106 | 107 | limg = self.imgloader(left) 108 | rimg = self.imgloader(right) 109 | ldisp = self.disploader(left_disp) 110 | 111 | # W, H = limg.size 112 | # limg = limg.resize((960, 288)) 113 | # rimg = rimg.resize((960, 288)) 114 | # ldisp = ldisp.resize((960, 288), Image.NEAREST) 115 | 116 | if self.training: 117 | w, h = limg.size 118 | tw, th = 512, 256 119 | 120 | x1 = random.randint(0, w - tw) 121 | y1 = random.randint(0, h - th) 122 | 123 | limg = limg.crop((x1, y1, x1 + tw, y1 + th)) 124 | rimg = rimg.crop((x1, y1, x1 + tw, y1 + th)) 125 | ldisp = np.ascontiguousarray(ldisp, dtype=np.float32) / 256 126 | ldisp = ldisp[y1:y1 + th, x1:x1 + tw] 127 | 128 | limg = self.transform(limg) 129 | rimg = self.transform(rimg) 130 | 131 | else: 132 | w, h = limg.size 133 | 134 | limg = limg.crop((w-1232, h-368, w, h)) 135 | rimg = rimg.crop((w-1232, h-368, w, h)) 136 | ldisp = ldisp.crop((w-1232, h-368, w, h)) 137 | ldisp = np.ascontiguousarray(ldisp, dtype=np.float32)/256 138 | 139 | limg = self.transform(limg) 140 | rimg = self.transform(rimg) 141 | 142 | # ldisp = ldisp * (960/W) 143 | return limg, rimg, ldisp, ldisp 144 | 145 | def __len__(self): 146 | return len(self.left) 147 | 148 | -------------------------------------------------------------------------------- /train_baseline.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torch.utils.data 4 | import torch.optim as optim 5 | import torch.nn.functional as F 6 | import torch.nn as nn 7 | import os 8 | import copy 9 | from tqdm import tqdm 10 | 11 | from dataloader import sceneflow_loader as sf 12 | import networks.submodule as sm 13 | import networks.U_net as un 14 | import networks.Aggregator as Agg 15 | 
import networks.feature_extraction as FE 16 | import loss_functions as lf 17 | 18 | 19 | parser = argparse.ArgumentParser(description='GraftNet') 20 | parser.add_argument('--no_cuda', action='store_true', default=False) 21 | parser.add_argument('--gpu_id', type=str, default='0, 1') 22 | parser.add_argument('--seed', type=str, default=0) 23 | parser.add_argument('--batch_size', type=int, default=6) 24 | parser.add_argument('--epoch', type=int, default=8) 25 | parser.add_argument('--data_path', type=str, default='/media/data/dataset/SceneFlow/') 26 | parser.add_argument('--save_path', type=str, default='trained_models/') 27 | parser.add_argument('--max_disp', type=int, default=192) 28 | parser.add_argument('--color_transform', action='store_true', default=False) 29 | args = parser.parse_args() 30 | 31 | if not args.no_cuda: 32 | os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID" 33 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id 34 | cuda = torch.cuda.is_available() 35 | 36 | torch.manual_seed(args.seed) 37 | if cuda: 38 | torch.cuda.manual_seed(args.seed) 39 | 40 | 41 | all_limg, all_rimg, all_ldisp, all_rdisp, test_limg, test_rimg, test_ldisp, test_rdisp = sf.sf_loader(args.data_path) 42 | 43 | trainLoader = torch.utils.data.DataLoader( 44 | sf.myDataset(all_limg, all_rimg, all_ldisp, all_rdisp, training=True, color_transform=args.color_transform), 45 | batch_size=args.batch_size, shuffle=True, num_workers=4, drop_last=False) 46 | 47 | 48 | fe_model = sm.GwcFeature(out_c=64).train() 49 | model = Agg.PSMAggregator(args.max_disp, udc=True).train() 50 | 51 | if cuda: 52 | fe_model = nn.DataParallel(fe_model.cuda()) 53 | model = nn.DataParallel(model.cuda()) 54 | 55 | params = [ 56 | {'params': fe_model.parameters(), 'lr': 1e-3}, 57 | {'params': model.parameters(), 'lr': 1e-3}, 58 | ] 59 | optimizer = optim.Adam(params, lr=1e-3, betas=(0.9, 0.999)) 60 | 61 | 62 | def train(imgL, imgR, gt_left, gt_right): 63 | imgL = torch.FloatTensor(imgL) 64 | imgR = torch.FloatTensor(imgR) 65 | gt_left = torch.FloatTensor(gt_left) 66 | gt_right = torch.FloatTensor(gt_right) 67 | 68 | if cuda: 69 | imgL, imgR = imgL.cuda(), imgR.cuda() 70 | gt_left, gt_right = gt_left.cuda(), gt_right.cuda() 71 | 72 | optimizer.zero_grad() 73 | 74 | left_fea = fe_model(imgL) 75 | right_fea = fe_model(imgR) 76 | 77 | loss1, loss2 = model(left_fea, right_fea, gt_left, training=True) 78 | 79 | loss1 = torch.mean(loss1) 80 | loss2 = torch.mean(loss2) 81 | 82 | loss = 0.1 * loss1 + loss2 83 | 84 | loss.backward() 85 | optimizer.step() 86 | 87 | return loss1.item(), loss2.item() 88 | 89 | 90 | def adjust_learning_rate(optimizer, epoch): 91 | if epoch <= 10: 92 | lr = 0.001 93 | else: 94 | lr = 0.0001 95 | # print(lr) 96 | for param_group in optimizer.param_groups: 97 | param_group['lr'] = lr 98 | 99 | 100 | def main(): 101 | 102 | # start_total_time = time.time() 103 | start_epoch = 1 104 | 105 | # checkpoint = torch.load('trained_gwcAgg/checkpoint_5_v1.tar') 106 | # model.load_state_dict(checkpoint['net']) 107 | # optimizer.load_state_dict(checkpoint['optimizer']) 108 | # start_epoch = checkpoint['epoch'] + 1 109 | # new_dict = {} 110 | # for k, v in checkpoint['fe_net'].items(): 111 | # k = "module." 
+ k 112 | # new_dict[k] = v 113 | # fe_model.load_state_dict(new_dict) 114 | # optimizer_fe.load_state_dict(checkpoint['fe_optimizer']) 115 | 116 | for epoch in range(start_epoch, args.epoch + start_epoch): 117 | print('This is %d-th epoch' % (epoch)) 118 | total_train_loss1 = 0 119 | total_train_loss2 = 0 120 | adjust_learning_rate(optimizer, epoch) 121 | 122 | for batch_id, (imgL, imgR, disp_L, disp_R) in enumerate(tqdm(trainLoader)): 123 | train_loss1, train_loss2 = train(imgL, imgR, disp_L, disp_R) 124 | total_train_loss1 += train_loss1 125 | total_train_loss2 += train_loss2 126 | avg_train_loss1 = total_train_loss1 / len(trainLoader) 127 | avg_train_loss2 = total_train_loss2 / len(trainLoader) 128 | print('Epoch %d average training loss1 = %.3f, average training loss2 = %.3f' % 129 | (epoch, avg_train_loss1, avg_train_loss2)) 130 | 131 | state = {'net': model.state_dict(), 132 | 'fe_net': fe_model.state_dict(), 133 | 'optimizer': optimizer.state_dict(), 134 | 'epoch': epoch} 135 | if not os.path.exists(args.save_path): 136 | os.mkdir(args.save_path) 137 | save_model_path = args.save_path + 'checkpoint_baseline_{}epoch.tar'.format(epoch) 138 | torch.save(state, save_model_path) 139 | 140 | torch.cuda.empty_cache() 141 | 142 | 143 | if __name__ == '__main__': 144 | main() 145 | 146 | -------------------------------------------------------------------------------- /train_adaptor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | import torch.optim as optim 4 | import torch.nn.functional as F 5 | import torch.nn as nn 6 | import os 7 | import copy 8 | from tqdm import tqdm 9 | import matplotlib.pyplot as plt 10 | import argparse 11 | 12 | from dataloader import sceneflow_loader as sf 13 | import networks.Aggregator as Agg 14 | import networks.U_net as un 15 | import networks.feature_extraction as FE 16 | import loss_functions as lf 17 | 18 | 19 | parser = argparse.ArgumentParser(description='GraftNet') 20 | parser.add_argument('--no_cuda', action='store_true', default=False) 21 | parser.add_argument('--gpu_id', type=str, default='0, 1') 22 | parser.add_argument('--seed', type=str, default=0) 23 | parser.add_argument('--batch_size', type=int, default=8) 24 | parser.add_argument('--epoch', type=int, default=1) 25 | parser.add_argument('--data_path', type=str, default='/media/data/dataset/SceneFlow/') 26 | parser.add_argument('--save_path', type=str, default='trained_models/') 27 | parser.add_argument('--load_path', type=str, default='trained_models/checkpoint_baseline_8epoch.tar') 28 | parser.add_argument('--max_disp', type=int, default=192) 29 | parser.add_argument('--color_transform', action='store_true', default=False) 30 | args = parser.parse_args() 31 | 32 | if not args.no_cuda: 33 | os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID" 34 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id 35 | cuda = torch.cuda.is_available() 36 | 37 | torch.manual_seed(args.seed) 38 | if cuda: 39 | torch.cuda.manual_seed(args.seed) 40 | 41 | all_limg, all_rimg, all_ldisp, all_rdisp, test_limg, test_rimg, test_ldisp, test_rdisp = sf.sf_loader(args.data_path) 42 | 43 | trainLoader = torch.utils.data.DataLoader( 44 | sf.myDataset(all_limg, all_rimg, all_ldisp, all_rdisp, training=True, color_transform=args.color_transform), 45 | batch_size=args.batch_size, shuffle=True, num_workers=4, drop_last=False) 46 | 47 | 48 | fe_model = FE.VGG_Feature(fixed_param=True).eval() 49 | model = un.U_Net_v4(img_ch=256, output_ch=64).train() 50 | 
print('Number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()]))) 51 | agg_model = Agg.PSMAggregator(args.max_disp, udc=True).eval() 52 | 53 | if cuda: 54 | fe_model = nn.DataParallel(fe_model.cuda()) 55 | model = nn.DataParallel(model.cuda()) 56 | agg_model = nn.DataParallel(agg_model.cuda()) 57 | 58 | agg_model.load_state_dict(torch.load(args.load_path)['net']) 59 | for p in agg_model.parameters(): 60 | p.requires_grad = False 61 | 62 | optimizer = optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999)) 63 | 64 | 65 | def train(imgL, imgR, gt_left, gt_right): 66 | imgL = torch.FloatTensor(imgL) 67 | imgR = torch.FloatTensor(imgR) 68 | gt_left = torch.FloatTensor(gt_left) 69 | gt_right = torch.FloatTensor(gt_right) 70 | 71 | if cuda: 72 | imgL, imgR, gt_left, gt_right = imgL.cuda(), imgR.cuda(), gt_left.cuda(), gt_right.cuda() 73 | 74 | optimizer.zero_grad() 75 | 76 | with torch.no_grad(): 77 | left_fea = fe_model(imgL) 78 | right_fea = fe_model(imgR) 79 | 80 | agg_left_fea = model(left_fea) 81 | agg_right_fea = model(right_fea) 82 | 83 | loss1, loss2 = agg_model(agg_left_fea, agg_right_fea, gt_left, training=True) 84 | loss1 = torch.mean(loss1) 85 | loss2 = torch.mean(loss2) 86 | loss = 0.1 * loss1 + loss2 87 | # loss = loss1 88 | 89 | loss.backward() 90 | optimizer.step() 91 | 92 | return loss1.item(), loss2.item() 93 | 94 | 95 | def adjust_learning_rate(optimizer, epoch): 96 | if epoch <= 10: 97 | lr = 0.001 98 | else: 99 | lr = 0.0001 100 | # print(lr) 101 | for param_group in optimizer.param_groups: 102 | param_group['lr'] = lr 103 | 104 | 105 | def main(): 106 | 107 | # start_total_time = time.time() 108 | start_epoch = 1 109 | 110 | # checkpoint = torch.load('trained_ft_CA_8.12/checkpoint_3_DA.tar') 111 | # agg_model.load_state_dict(checkpoint['net']) 112 | # optimizer.load_state_dict(checkpoint['optimizer']) 113 | # start_epoch = checkpoint['epoch'] + 1 114 | 115 | for epoch in range(start_epoch, args.epoch + start_epoch): 116 | print('This is %d-th epoch' % (epoch)) 117 | total_train_loss1 = 0 118 | total_train_loss2 = 0 119 | adjust_learning_rate(optimizer, epoch) 120 | 121 | for batch_id, (imgL, imgR, disp_L, disp_R) in enumerate(tqdm(trainLoader)): 122 | train_loss1, train_loss2 = train(imgL, imgR, disp_L, disp_R) 123 | total_train_loss1 += train_loss1 124 | total_train_loss2 += train_loss2 125 | avg_train_loss1 = total_train_loss1 / len(trainLoader) 126 | avg_train_loss2 = total_train_loss2 / len(trainLoader) 127 | print('Epoch %d average training loss1 = %.3f, average training loss2 = %.3f' % 128 | (epoch, avg_train_loss1, avg_train_loss2)) 129 | 130 | state = {'fa_net': model.state_dict(), 131 | 'net': agg_model.state_dict(), 132 | 'optimizer': optimizer.state_dict(), 133 | 'epoch': epoch} 134 | if not os.path.exists(args.save_path): 135 | os.mkdir(args.save_path) 136 | save_model_path = args.save_path + 'checkpoint_adaptor_{}epoch.tar'.format(epoch) 137 | torch.save(state, save_model_path) 138 | 139 | torch.cuda.empty_cache() 140 | 141 | 142 | if __name__ == '__main__': 143 | main() 144 | 145 | -------------------------------------------------------------------------------- /retrain_CostAggregation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | import torch.optim as optim 4 | import torch.nn.functional as F 5 | import torch.nn as nn 6 | import os 7 | import copy 8 | from tqdm import tqdm, trange 9 | import matplotlib.pyplot as plt 10 | import 
argparse 11 | 12 | from dataloader import sceneflow_loader as sf 13 | import networks.Aggregator as Agg 14 | import networks.submodule as sm 15 | import networks.U_net as un 16 | import networks.feature_extraction as FE 17 | import loss_functions as lf 18 | 19 | 20 | parser = argparse.ArgumentParser(description='GraftNet') 21 | parser.add_argument('--no_cuda', action='store_true', default=False) 22 | parser.add_argument('--gpu_id', type=str, default='0, 1') 23 | parser.add_argument('--seed', type=str, default=0) 24 | parser.add_argument('--batch_size', type=int, default=6) 25 | parser.add_argument('--epoch', type=int, default=10) 26 | parser.add_argument('--data_path', type=str, default='/media/data/dataset/SceneFlow/') 27 | parser.add_argument('--save_path', type=str, default='trained_models/') 28 | parser.add_argument('--load_path', type=str, default='trained_models/checkpoint_adaptor_1epoch.tar') 29 | parser.add_argument('--max_disp', type=int, default=192) 30 | parser.add_argument('--color_transform', action='store_true', default=False) 31 | args = parser.parse_args() 32 | 33 | if not args.no_cuda: 34 | os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID" 35 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id 36 | cuda = torch.cuda.is_available() 37 | 38 | torch.manual_seed(args.seed) 39 | if cuda: 40 | torch.cuda.manual_seed(args.seed) 41 | 42 | all_limg, all_rimg, all_ldisp, all_rdisp, test_limg, test_rimg, test_ldisp, test_rdisp = sf.sf_loader(args.data_path) 43 | 44 | trainLoader = torch.utils.data.DataLoader( 45 | sf.myDataset(all_limg, all_rimg, all_ldisp, all_rdisp, training=True, color_transform=args.color_transform), 46 | batch_size=args.batch_size, shuffle=True, num_workers=4, drop_last=False) 47 | 48 | 49 | fe_model = FE.VGG_Feature(fixed_param=True).eval() 50 | adaptor = un.U_Net_v4(img_ch=256, output_ch=64).eval() 51 | model = Agg.PSMAggregator(args.max_disp, udc=True).train() 52 | print('Number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()]))) 53 | 54 | if cuda: 55 | fe_model = nn.DataParallel(fe_model.cuda()) 56 | adaptor = nn.DataParallel(adaptor.cuda()) 57 | model = nn.DataParallel(model.cuda()) 58 | 59 | adaptor.load_state_dict(torch.load(args.load_path)['fa_net']) 60 | for p in adaptor.parameters(): 61 | p.requires_grad = False 62 | 63 | optimizer = optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999)) 64 | 65 | 66 | def train(imgL, imgR, gt_left, gt_right): 67 | imgL = torch.FloatTensor(imgL) 68 | imgR = torch.FloatTensor(imgR) 69 | gt_left = torch.FloatTensor(gt_left) 70 | gt_right = torch.FloatTensor(gt_right) 71 | 72 | if cuda: 73 | imgL, imgR, gt_left, gt_right = imgL.cuda(), imgR.cuda(), gt_left.cuda(), gt_right.cuda() 74 | 75 | optimizer.zero_grad() 76 | 77 | with torch.no_grad(): 78 | left_fea = fe_model(imgL) 79 | right_fea = fe_model(imgR) 80 | 81 | left_fea = adaptor(left_fea) 82 | right_fea = adaptor(right_fea) 83 | 84 | loss1, loss2 = model(left_fea, right_fea, gt_left, training=True) 85 | loss1 = torch.mean(loss1) 86 | loss2 = torch.mean(loss2) 87 | loss = 0.1 * loss1 + loss2 88 | # loss = loss1 89 | 90 | loss.backward() 91 | optimizer.step() 92 | 93 | return loss1.item(), loss2.item() 94 | 95 | 96 | def adjust_learning_rate(optimizer, epoch): 97 | if epoch <= 5: 98 | lr = 0.001 99 | else: 100 | lr = 0.0001 101 | # print(lr) 102 | for param_group in optimizer.param_groups: 103 | param_group['lr'] = lr 104 | 105 | 106 | def main(): 107 | 108 | # start_total_time = time.time() 109 | start_epoch = 1 110 | 111 | # 
checkpoint = torch.load('trained_ft_costAgg/checkpoint_1_v4.tar') 112 | # CostAggregator.load_state_dict(checkpoint['net']) 113 | # optimizer.load_state_dict(checkpoint['optimizer']) 114 | # start_epoch = checkpoint['epoch'] + 1 115 | 116 | for epoch in range(start_epoch, args.epoch + start_epoch): 117 | print('This is %d-th epoch' % (epoch)) 118 | total_train_loss1 = 0 119 | total_train_loss2 = 0 120 | adjust_learning_rate(optimizer, epoch) 121 | # 122 | 123 | for batch_id, (imgL, imgR, disp_L, disp_R) in enumerate(tqdm(trainLoader)): 124 | train_loss1, train_loss2 = train(imgL, imgR, disp_L, disp_R) 125 | total_train_loss1 += train_loss1 126 | total_train_loss2 += train_loss2 127 | avg_train_loss1 = total_train_loss1 / len(trainLoader) 128 | avg_train_loss2 = total_train_loss2 / len(trainLoader) 129 | print('Epoch %d average training loss1 = %.3f, average training loss2 = %.3f' % 130 | (epoch, avg_train_loss1, avg_train_loss2)) 131 | 132 | state = {'fa_net': adaptor.state_dict(), 133 | 'net': model.state_dict(), 134 | 'optimizer': optimizer.state_dict(), 135 | 'epoch': epoch} 136 | if not os.path.exists(args.save_path): 137 | os.mkdir(args.save_path) 138 | save_model_path = args.save_path + 'checkpoint_final_{}epoch.tar'.format(epoch) 139 | torch.save(state, save_model_path) 140 | 141 | torch.cuda.empty_cache() 142 | 143 | 144 | if __name__ == '__main__': 145 | main() 146 | 147 | -------------------------------------------------------------------------------- /loss_functions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import cv2 6 | from PIL import Image 7 | import matplotlib.pyplot as plt 8 | 9 | 10 | def disp2distribute(disp_gt, max_disp, b=2): 11 | disp_gt = disp_gt.unsqueeze(1) 12 | disp_range = torch.arange(0, max_disp).view(1, -1, 1, 1).float().cuda() 13 | gt_distribute = torch.exp(-torch.abs(disp_range - disp_gt) / b) 14 | gt_distribute = gt_distribute / (torch.sum(gt_distribute, dim=1, keepdim=True) + 1e-8) 15 | return gt_distribute 16 | 17 | 18 | def CEloss(disp_gt, max_disp, gt_distribute, pred_distribute): 19 | mask = (disp_gt > 0) & (disp_gt < max_disp) 20 | 21 | pred_distribute = torch.log(pred_distribute + 1e-8) 22 | 23 | ce_loss = torch.sum(-gt_distribute * pred_distribute, dim=1) 24 | ce_loss = torch.mean(ce_loss[mask]) 25 | return ce_loss 26 | 27 | 28 | def gradient_x(img): 29 | img = F.pad(img, (0, 1, 0, 0), mode="replicate") 30 | gx = img[:, :, :, :-1] - img[:, :, :, 1:] 31 | return gx 32 | 33 | 34 | def gradient_y(img): 35 | img = F.pad(img, (0, 0, 0, 1), mode="replicate") 36 | gy = img[:, :, :-1, :] - img[:, :, 1:, :] 37 | return gy 38 | 39 | 40 | def smooth_loss(img, disp): 41 | img_gx = gradient_x(img) 42 | img_gy = gradient_y(img) 43 | disp_gx = gradient_x(disp) 44 | disp_gy = gradient_y(disp) 45 | 46 | weight_x = torch.exp(-torch.mean(torch.abs(img_gx), dim=1, keepdim=True)) 47 | weight_y = torch.exp(-torch.mean(torch.abs(img_gy), dim=1, keepdim=True)) 48 | smoothness_x = torch.abs(disp_gx * weight_x) 49 | smoothness_y = torch.abs(disp_gy * weight_y) 50 | smoothness_loss = smoothness_x + smoothness_y 51 | 52 | return torch.mean(smoothness_loss) 53 | 54 | 55 | 56 | def occlusion_mask(left_disp, right_disp, threshold=1): 57 | # left_disp = left_disp.unsqueeze(1) 58 | # right_disp = right_disp.unsqueeze(1) 59 | 60 | B, _, H, W = left_disp.size() 61 | 62 | x_base = torch.linspace(0, 1, W).repeat(B, H, 
1).type_as(right_disp) 63 | y_base = torch.linspace(0, 1, H).repeat(B, W, 1).transpose(1, 2).type_as(right_disp) 64 | 65 | flow_field = torch.stack((x_base - left_disp.squeeze(1) / W, y_base), dim=3) 66 | 67 | recon_left_disp = F.grid_sample(right_disp, 2 * flow_field - 1, mode='bilinear', padding_mode='zeros') 68 | 69 | lr_check = torch.abs(recon_left_disp - left_disp) 70 | mask = lr_check > threshold 71 | 72 | return mask 73 | 74 | 75 | def reconstruction(right, disp): 76 | b, _, h, w = right.size() 77 | 78 | x_base = torch.linspace(0, 1, w).repeat(b, h, 1).type_as(right) 79 | y_base = torch.linspace(0, 1, h).repeat(b, w, 1).transpose(1, 2).type_as(right) 80 | 81 | flow_field = torch.stack((x_base - disp / w, y_base), dim=3) 82 | 83 | recon_left = F.grid_sample(right, 2 * flow_field - 1, mode='bilinear', padding_mode='zeros') 84 | return recon_left 85 | 86 | 87 | def NT_Xent_loss(positive_simi, negative_simi, t): 88 | loss = torch.exp(positive_simi / t) / \ 89 | (torch.exp(positive_simi / t) + torch.sum(torch.exp(negative_simi / t), dim=4)) 90 | loss = -torch.log(loss + 1e-9) 91 | return loss 92 | 93 | 94 | class FeatureSimilarityLoss(nn.Module): 95 | def __init__(self, max_disp): 96 | super(FeatureSimilarityLoss, self).__init__() 97 | self.max_disp = max_disp 98 | self.m = 0.3 99 | self.nega_num = 1 100 | 101 | def forward(self, left_fea, right_fea, left_disp, right_disp): 102 | B, _, H, W = left_fea.size() 103 | 104 | down_disp = F.interpolate(left_disp, (H, W), mode='nearest') / 4. 105 | # down_img = F.interpolate(left_img, (H, W), mode='nearest') 106 | # down_img = torch.mean(down_img, dim=1, keepdim=True) 107 | 108 | # t_map = self.t_net(left_fea) 109 | 110 | # create negative samples 111 | random_offset = torch.rand(B, self.nega_num, H, W).cuda() * 2 + 1 112 | random_sign = torch.sign(torch.rand(B, self.nega_num, H, W).cuda() - 0.5) 113 | random_offset *= random_sign 114 | negative_disp = down_disp + random_offset 115 | 116 | positive_recon = reconstruction(right_fea, down_disp.squeeze(1)) 117 | negative_recon = [] 118 | for i in range(self.nega_num): 119 | negative_recon.append(reconstruction(right_fea, negative_disp[:, i])) 120 | negative_recon = torch.stack(negative_recon, dim=4) 121 | 122 | left_fea = F.normalize(left_fea, dim=1) 123 | positive_recon = F.normalize(positive_recon, dim=1) 124 | negative_recon = F.normalize(negative_recon, dim=1) 125 | 126 | positive_simi = (torch.sum(left_fea * positive_recon, dim=1, keepdim=True) + 1) / 2 127 | negative_simi = (torch.sum(left_fea.unsqueeze(4) * negative_recon, dim=1, keepdim=True) + 1) / 2 128 | 129 | judge_mat_p = torch.zeros_like(positive_simi) 130 | judge_mat_n = torch.zeros_like(negative_simi) 131 | if torch.sum(positive_simi < judge_mat_p) > 0 or torch.sum(negative_simi < judge_mat_n) > 0: 132 | print('cosine_simi < 0') 133 | 134 | # hinge loss 135 | # dist = self.m + negative_simi - positive_simi 136 | # criteria = torch.zeros_like(dist) 137 | # loss, _ = torch.max(torch.cat((dist, criteria), dim=1), dim=1, keepdim=True) 138 | 139 | # NT-Xent loss 140 | # loss = NT_Xent_loss(positive_simi, negative_simi, t=t_map) 141 | loss = NT_Xent_loss(positive_simi, negative_simi, t=0.2) 142 | 143 | # img_grad = torch.sqrt(gradient_x(down_img) ** 2 + gradient_y(down_img) ** 2) 144 | # weight = torch.exp(-img_grad) 145 | # loss = loss * weight 146 | 147 | occ_mask = occlusion_mask(left_disp, right_disp, threshold=1) 148 | occ_mask = F.interpolate(occ_mask.float(), (H, W), mode='nearest') 149 | valid_mask = (down_disp > 0) & (down_disp < 
self.max_disp // 4) & (occ_mask == 0) 150 | 151 | return torch.mean(loss[valid_mask]) 152 | 153 | 154 | def gram_matrix(feature): 155 | B, C, H, W = feature.size() 156 | feature = feature.view(B, C, H * W) 157 | feature_t = feature.transpose(1, 2) 158 | gram_m = torch.bmm(feature, feature_t) / (H * W) 159 | return gram_m 160 | 161 | 162 | def gram_matrix_v2(feature): 163 | B, C, H, W = feature.size() 164 | feature = feature.view(B * C, H * W) 165 | gram_m = torch.mm(feature, feature.t()) / (B * C * H * W) 166 | return gram_m 167 | 168 | 169 | if __name__ == '__main__': 170 | 171 | a = torch.rand(2, 256, 64, 128) 172 | b = torch.rand(2, 256, 64, 128) 173 | 174 | gram_a = gram_matrix(a) 175 | gram_b = gram_matrix(b) 176 | print(F.mse_loss(gram_a, gram_b)) 177 | 178 | ga_2 = gram_matrix_v2(a) 179 | gb_2 = gram_matrix_v2(b) 180 | print(F.mse_loss(ga_2, gb_2)) 181 | -------------------------------------------------------------------------------- /networks/vgg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | __all__ = [ 6 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 7 | 'vgg19_bn', 'vgg19', 8 | ] 9 | 10 | 11 | class VGG(nn.Module): 12 | 13 | def __init__(self, features, num_classes=1000, init_weights=True): 14 | super(VGG, self).__init__() 15 | self.features = features 16 | self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) 17 | self.classifier = nn.Sequential( 18 | nn.Linear(512 * 7 * 7, 4096), 19 | nn.ReLU(True), 20 | nn.Dropout(), 21 | nn.Linear(4096, 4096), 22 | nn.ReLU(True), 23 | nn.Dropout(), 24 | nn.Linear(4096, num_classes), 25 | ) 26 | if init_weights: 27 | self._initialize_weights() 28 | 29 | def forward(self, x): 30 | x = self.features(x) 31 | x = self.avgpool(x) 32 | x = torch.flatten(x, 1) 33 | x = self.classifier(x) 34 | return x 35 | 36 | def _initialize_weights(self): 37 | for m in self.modules(): 38 | if isinstance(m, nn.Conv2d): 39 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 40 | if m.bias is not None: 41 | nn.init.constant_(m.bias, 0) 42 | elif isinstance(m, nn.BatchNorm2d): 43 | nn.init.constant_(m.weight, 1) 44 | nn.init.constant_(m.bias, 0) 45 | elif isinstance(m, nn.Linear): 46 | nn.init.normal_(m.weight, 0, 0.01) 47 | nn.init.constant_(m.bias, 0) 48 | 49 | 50 | def make_layers(cfg, batch_norm=False): 51 | layers = [] 52 | in_channels = 3 53 | for v in cfg: 54 | if v == 'M': 55 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 56 | else: 57 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 58 | if batch_norm: 59 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 60 | else: 61 | layers += [conv2d, nn.ReLU(inplace=True)] 62 | in_channels = v 63 | return nn.Sequential(*layers) 64 | 65 | 66 | cfgs = { 67 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 68 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 69 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 70 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 71 | } 72 | 73 | 74 | def _vgg(arch, cfg, batch_norm, pretrained, progress, **kwargs): 75 | if pretrained: 76 | kwargs['init_weights'] = False 77 | model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) 78 | return model 79 | 80 | 81 | def vgg11(pretrained=False, progress=True, **kwargs): 82 | r"""VGG 11-layer model (configuration 
"A") from 83 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 84 | 85 | Args: 86 | pretrained (bool): If True, returns a model pre-trained on ImageNet 87 | progress (bool): If True, displays a progress bar of the download to stderr 88 | """ 89 | return _vgg('vgg11', 'A', False, pretrained, progress, **kwargs) 90 | 91 | 92 | def vgg11_bn(pretrained=False, progress=True, **kwargs): 93 | r"""VGG 11-layer model (configuration "A") with batch normalization 94 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 95 | 96 | Args: 97 | pretrained (bool): If True, returns a model pre-trained on ImageNet 98 | progress (bool): If True, displays a progress bar of the download to stderr 99 | """ 100 | return _vgg('vgg11_bn', 'A', True, pretrained, progress, **kwargs) 101 | 102 | 103 | def vgg13(pretrained=False, progress=True, **kwargs): 104 | r"""VGG 13-layer model (configuration "B") 105 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 106 | 107 | Args: 108 | pretrained (bool): If True, returns a model pre-trained on ImageNet 109 | progress (bool): If True, displays a progress bar of the download to stderr 110 | """ 111 | return _vgg('vgg13', 'B', False, pretrained, progress, **kwargs) 112 | 113 | 114 | def vgg13_bn(pretrained=False, progress=True, **kwargs): 115 | r"""VGG 13-layer model (configuration "B") with batch normalization 116 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 117 | 118 | Args: 119 | pretrained (bool): If True, returns a model pre-trained on ImageNet 120 | progress (bool): If True, displays a progress bar of the download to stderr 121 | """ 122 | return _vgg('vgg13_bn', 'B', True, pretrained, progress, **kwargs) 123 | 124 | 125 | def vgg16(pretrained=False, progress=True, **kwargs): 126 | r"""VGG 16-layer model (configuration "D") 127 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 128 | 129 | Args: 130 | pretrained (bool): If True, returns a model pre-trained on ImageNet 131 | progress (bool): If True, displays a progress bar of the download to stderr 132 | """ 133 | return _vgg('vgg16', 'D', False, pretrained, progress, **kwargs) 134 | 135 | 136 | def vgg16_bn(pretrained=False, progress=True, **kwargs): 137 | r"""VGG 16-layer model (configuration "D") with batch normalization 138 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 139 | 140 | Args: 141 | pretrained (bool): If True, returns a model pre-trained on ImageNet 142 | progress (bool): If True, displays a progress bar of the download to stderr 143 | """ 144 | return _vgg('vgg16_bn', 'D', True, pretrained, progress, **kwargs) 145 | 146 | 147 | def vgg19(pretrained=False, progress=True, **kwargs): 148 | r"""VGG 19-layer model (configuration "E") 149 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 150 | 151 | Args: 152 | pretrained (bool): If True, returns a model pre-trained on ImageNet 153 | progress (bool): If True, displays a progress bar of the download to stderr 154 | """ 155 | return _vgg('vgg19', 'E', False, pretrained, progress, **kwargs) 156 | 157 | 158 | def vgg19_bn(pretrained=False, progress=True, **kwargs): 159 | r"""VGG 19-layer model (configuration 'E') with batch normalization 160 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 161 | 162 | Args: 163 | pretrained (bool): If True, returns a model pre-trained on ImageNet 164 | progress (bool): If True, displays a progress bar of the download to stderr 165 | 
""" 166 | return _vgg('vgg19_bn', 'E', True, pretrained, progress, **kwargs) 167 | -------------------------------------------------------------------------------- /networks/submodule.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import torch.nn as nn 4 | import torch.utils.data 5 | from torch.autograd import Variable 6 | import torch.nn.functional as F 7 | from torchvision import models 8 | import math 9 | import numpy as np 10 | import torchvision.transforms as transforms 11 | import PIL 12 | import os 13 | import matplotlib.pyplot as plt 14 | from networks.resnet import ResNet, Bottleneck, BasicBlock_Res 15 | 16 | 17 | def convbn(in_planes, out_planes, kernel_size, stride, pad, dilation): 18 | 19 | return nn.Sequential(nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=dilation if dilation > 1 else pad, dilation = dilation, bias=False), 20 | nn.BatchNorm2d(out_planes)) 21 | 22 | 23 | def convbn_3d(in_planes, out_planes, kernel_size, stride, pad): 24 | 25 | return nn.Sequential(nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, padding=pad, stride=stride,bias=False), 26 | nn.BatchNorm3d(out_planes)) 27 | 28 | 29 | class BasicBlock(nn.Module): 30 | expansion = 1 31 | def __init__(self, inplanes, planes, stride, downsample, pad, dilation): 32 | super(BasicBlock, self).__init__() 33 | 34 | self.conv1 = nn.Sequential(convbn(inplanes, planes, 3, stride, pad, dilation), 35 | nn.ReLU(inplace=True)) 36 | 37 | self.conv2 = convbn(planes, planes, 3, 1, pad, dilation) 38 | 39 | self.downsample = downsample 40 | self.stride = stride 41 | 42 | def forward(self, x): 43 | out = self.conv1(x) 44 | out = self.conv2(out) 45 | 46 | if self.downsample is not None: 47 | x = self.downsample(x) 48 | 49 | out += x 50 | 51 | return out 52 | 53 | 54 | class DisparityRegression(nn.Module): 55 | 56 | def __init__(self, maxdisp, win_size): 57 | super(DisparityRegression, self).__init__() 58 | self.max_disp = maxdisp 59 | self.win_size = win_size 60 | 61 | def forward(self, x): 62 | disp = torch.arange(0, self.max_disp).view(1, -1, 1, 1).float().to(x.device) 63 | 64 | if self.win_size > 0: 65 | max_d = torch.argmax(x, dim=1, keepdim=True) 66 | d_value = [] 67 | prob_value = [] 68 | for d in range(-self.win_size, self.win_size + 1): 69 | index = max_d + d 70 | index[index < 0] = 0 71 | index[index > x.shape[1] - 1] = x.shape[1] - 1 72 | d_value.append(index) 73 | 74 | prob = torch.gather(x, dim=1, index=index) 75 | prob_value.append(prob) 76 | 77 | part_x = torch.cat(prob_value, dim=1) 78 | part_x = part_x / (torch.sum(part_x, dim=1, keepdim=True) + 1e-8) 79 | part_d = torch.cat(d_value, dim=1).float() 80 | out = torch.sum(part_x * part_d, dim=1) 81 | 82 | else: 83 | out = torch.sum(x * disp, 1) 84 | 85 | return out 86 | 87 | 88 | class GwcFeature(nn.Module): 89 | def __init__(self, out_c, fuse_mode='add'): 90 | super(GwcFeature, self).__init__() 91 | self.inplanes = 32 92 | self.fuse_mode = fuse_mode 93 | 94 | self.firstconv = nn.Sequential(convbn(3, 32, 3, 2, 1, 1), 95 | nn.ReLU(inplace=True), 96 | convbn(32, 32, 3, 1, 1, 1), 97 | nn.ReLU(inplace=True), 98 | convbn(32, 32, 3, 1, 1, 1), 99 | nn.ReLU(inplace=True)) 100 | 101 | self.layer1 = self._make_layer(BasicBlock, 32, 3, 1, 1, 1) 102 | self.layer2 = self._make_layer(BasicBlock, 64, 16, 2, 1, 1) 103 | self.layer3 = self._make_layer(BasicBlock, 128, 3, 1, 1, 1) 104 | self.layer4 = self._make_layer(BasicBlock, 128, 3, 1, 1, 2) 
105 | 106 | if self.fuse_mode == 'cat': 107 | self.lastconv = nn.Sequential(convbn(320, 128, 3, 1, 1, 1), 108 | nn.ReLU(inplace=True), 109 | nn.Conv2d(128, out_c, kernel_size=1, padding=0, stride=1, bias=False)) 110 | elif self.fuse_mode == 'add': 111 | self.l1_conv = nn.Conv2d(32, out_c, 1, stride=1, padding=0, bias=False) 112 | self.l2_conv = nn.Conv2d(64, out_c, 1, stride=1, padding=0, bias=False) 113 | self.l4_conv = nn.Conv2d(128, out_c, 1, stride=1, padding=0, bias=False) 114 | elif self.fuse_mode == 'add_sa': 115 | self.l1_conv = nn.Conv2d(64, out_c, 1, stride=1, padding=0, bias=False) 116 | self.l4_conv = nn.Conv2d(64, out_c, 1, stride=1, padding=0, bias=False) 117 | self.sa = nn.Sequential(convbn(2 * out_c, 2 * out_c, 3, 1, 1, 1), 118 | nn.ReLU(inplace=True), 119 | nn.Conv2d(2 * out_c, 2, 3, stride=1, padding=1, bias=False)) 120 | 121 | def _make_layer(self, block, planes, blocks, stride, pad, dilation): 122 | downsample = None 123 | if stride != 1 or self.inplanes != planes * block.expansion: 124 | downsample = nn.Sequential( 125 | nn.Conv2d(self.inplanes, planes * block.expansion, 126 | kernel_size=1, stride=stride, bias=False), 127 | nn.BatchNorm2d(planes * block.expansion),) 128 | 129 | layers = [] 130 | layers.append(block(self.inplanes, planes, stride, downsample, pad, dilation)) 131 | self.inplanes = planes * block.expansion 132 | for i in range(1, blocks): 133 | layers.append(block(self.inplanes, planes,1,None,pad,dilation)) 134 | 135 | return nn.Sequential(*layers) 136 | 137 | def forward(self, x): 138 | output = self.firstconv(x) 139 | output_l1 = self.layer1(output) 140 | output_l2 = self.layer2(output_l1) 141 | output_l3 = self.layer3(output_l2) 142 | output_l4 = self.layer4(output_l3) 143 | 144 | output_l1 = F.interpolate(output_l1, (output_l4.size()[2], output_l4.size()[3]), 145 | mode='bilinear', align_corners=True) 146 | 147 | if self.fuse_mode == 'cat': 148 | cat_feature = torch.cat((output_l2, output_l3, output_l4), dim=1) 149 | output_feature = self.lastconv(cat_feature) 150 | elif self.fuse_mode == 'add': 151 | output_l1 = self.l1_conv(output_l1) 152 | output_l4 = self.l4_conv(output_l4) 153 | output_feature = output_l1 + output_l4 154 | elif self.fuse_mode == 'add_sa': 155 | output_l1 = self.l1_conv(output_l1) 156 | output_l4 = self.l4_conv(output_l4) 157 | 158 | attention_map = self.sa(torch.cat((output_l1, output_l4), dim=1)) 159 | attention_map = torch.sigmoid(attention_map) 160 | output_feature = output_l1 * attention_map[:, 0].unsqueeze(1) + \ 161 | output_l4 * attention_map[:, 1].unsqueeze(1) 162 | 163 | return output_feature 164 | 165 | 166 | -------------------------------------------------------------------------------- /dataloader/sceneflow_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | from dataloader import readpfm as rp 4 | import dataloader.preprocess 5 | import torch.utils.data as data 6 | import torchvision.transforms as transforms 7 | import numpy as np 8 | import random 9 | 10 | IMG_EXTENSIONS= [ 11 | '.jpg', '.JPG', '.jpeg', '.JPEG', 12 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP' 13 | ] 14 | 15 | 16 | def is_image_file(filename): 17 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 18 | 19 | 20 | # filepath = '/media/data/LiuBiyang/SceneFlow/' 21 | def sf_loader(filepath): 22 | 23 | classes = [d for d in os.listdir(filepath) if os.path.isdir(os.path.join(filepath, d))] 24 | image = [img for img in classes if 
img.find('frames_cleanpass') > -1] 25 | disparity = [disp for disp in classes if disp.find('disparity') > -1] 26 | 27 | all_left_img = [] 28 | all_right_img = [] 29 | all_left_disp = [] 30 | all_right_disp = [] 31 | test_left_img = [] 32 | test_right_img = [] 33 | test_left_disp = [] 34 | test_right_disp = [] 35 | 36 | monkaa_img = filepath + [x for x in image if 'monkaa' in x][0] 37 | monkaa_disp = filepath + [x for x in disparity if 'monkaa' in x][0] 38 | monkaa_dir = os.listdir(monkaa_img) 39 | for dd in monkaa_dir: 40 | left_path = monkaa_img + '/' + dd + '/left/' 41 | right_path = monkaa_img + '/' + dd + '/right/' 42 | disp_path = monkaa_disp + '/' + dd + '/left/' 43 | rdisp_path = monkaa_disp + '/' + dd + '/right/' 44 | 45 | left_imgs = os.listdir(left_path) 46 | for img in left_imgs: 47 | img_path = os.path.join(left_path, img) 48 | if is_image_file(img_path): 49 | all_left_img.append(img_path) 50 | all_right_img.append(os.path.join(right_path, img)) 51 | all_left_disp.append(disp_path + img.split(".")[0] + '.pfm') 52 | all_right_disp.append(rdisp_path + img.split(".")[0] + '.pfm') 53 | 54 | flying_img = filepath + [x for x in image if 'flying' in x][0] 55 | flying_disp = filepath + [x for x in disparity if 'flying' in x][0] 56 | fimg_train = flying_img + '/TRAIN/' 57 | fimg_test = flying_img + '/TEST/' 58 | fdisp_train = flying_disp + '/TRAIN/' 59 | fdisp_test = flying_disp + '/TEST/' 60 | fsubdir = ['A', 'B', 'C'] 61 | 62 | for dd in fsubdir: 63 | imgs_path = fimg_train + dd + '/' 64 | disps_path = fdisp_train + dd + '/' 65 | imgs = os.listdir(imgs_path) 66 | for cc in imgs: 67 | left_path = imgs_path + cc + '/left/' 68 | right_path = imgs_path + cc + '/right/' 69 | disp_path = disps_path + cc + '/left/' 70 | rdisp_path = disps_path + cc + '/right/' 71 | 72 | left_imgs = os.listdir(left_path) 73 | for img in left_imgs: 74 | img_path = os.path.join(left_path, img) 75 | if is_image_file(img_path): 76 | all_left_img.append(img_path) 77 | all_right_img.append(os.path.join(right_path, img)) 78 | all_left_disp.append(disp_path + img.split(".")[0] + '.pfm') 79 | all_right_disp.append(rdisp_path + img.split(".")[0] + '.pfm') 80 | 81 | for dd in fsubdir: 82 | imgs_path = fimg_test + dd + '/' 83 | disps_path = fdisp_test + dd + '/' 84 | imgs = os.listdir(imgs_path) 85 | for cc in imgs: 86 | left_path = imgs_path + cc + '/left/' 87 | right_path = imgs_path + cc + '/right/' 88 | disp_path = disps_path + cc + '/left/' 89 | rdisp_path = disps_path + cc + '/right/' 90 | 91 | left_imgs = os.listdir(left_path) 92 | for img in left_imgs: 93 | img_path = os.path.join(left_path, img) 94 | if is_image_file(img_path): 95 | test_left_img.append(img_path) 96 | test_right_img.append(os.path.join(right_path, img)) 97 | test_left_disp.append(disp_path + img.split(".")[0] + '.pfm') 98 | test_right_disp.append(rdisp_path + img.split(".")[0] + '.pfm') 99 | 100 | driving_img = filepath + [x for x in image if 'driving' in x][0] 101 | driving_disp = filepath + [x for x in disparity if 'driving' in x][0] 102 | dsubdir1 = ['15mm_focallength', '35mm_focallength'] 103 | dsubdir2 = ['scene_backwards', 'scene_forwards'] 104 | dsubdir3 = ['fast', 'slow'] 105 | for d in dsubdir1: 106 | img_path1 = driving_img + '/' + d + '/' 107 | disp_path1 = driving_disp + '/' + d + '/' 108 | for dd in dsubdir2: 109 | img_path2 = img_path1 + dd + '/' 110 | disp_path2 = disp_path1 + dd + '/' 111 | for ddd in dsubdir3: 112 | img_path3 = img_path2 + ddd + '/' 113 | disp_path3 = disp_path2 + ddd + '/' 114 | 115 | left_path = img_path3 + 
'left/' 116 | right_path = img_path3 + 'right/' 117 | disp_path = disp_path3 + 'left/' 118 | rdisp_path = disp_path3 + 'right/' 119 | 120 | left_imgs = os.listdir(left_path) 121 | for img in left_imgs: 122 | img_path = os.path.join(left_path, img) 123 | if is_image_file(img_path): 124 | all_left_img.append(img_path) 125 | all_right_img.append(os.path.join(right_path, img)) 126 | all_left_disp.append(disp_path + img.split(".")[0] + '.pfm') 127 | all_right_disp.append(rdisp_path + img.split(".")[0] + '.pfm') 128 | 129 | return all_left_img, all_right_img, all_left_disp, all_right_disp, \ 130 | test_left_img, test_right_img, test_left_disp, test_right_disp 131 | 132 | 133 | def img_loader(path): 134 | return Image.open(path).convert('RGB') 135 | 136 | 137 | def disparity_loader(path): 138 | return rp.readPFM(path) 139 | 140 | 141 | def random_transform(left_img, right_img): 142 | if np.random.rand(1) <= 0.2: 143 | left_img = transforms.Grayscale(num_output_channels=3)(left_img) 144 | right_img = transforms.Grayscale(num_output_channels=3)(right_img) 145 | else: 146 | left_img = transforms.ColorJitter(brightness=0.5, contrast=0.5, hue=0.1)(left_img) 147 | right_img = transforms.ColorJitter(brightness=0.5, contrast=0.5, hue=0.1)(right_img) 148 | return left_img, right_img 149 | 150 | 151 | class myDataset(data.Dataset): 152 | 153 | def __init__(self, left, right, left_disp, right_disp, training, imgloader=img_loader, dploader = disparity_loader, 154 | color_transform = False): 155 | self.left = left 156 | self.right = right 157 | self.disp_L = left_disp 158 | self.disp_R = right_disp 159 | self.imgloader = imgloader 160 | self.dploader = dploader 161 | self.training = training 162 | self.img_transorm = transforms.Compose([ 163 | transforms.ToTensor(), 164 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 165 | self.color_transform = color_transform 166 | 167 | def __getitem__(self, index): 168 | left = self.left[index] 169 | right = self.right[index] 170 | disp_L = self.disp_L[index] 171 | disp_R = self.disp_R[index] 172 | 173 | left_img = self.imgloader(left) 174 | right_img = self.imgloader(right) 175 | dataL, _ = self.dploader(disp_L) 176 | dataL = np.ascontiguousarray(dataL, dtype=np.float32) 177 | dataR, _ = self.dploader(disp_R) 178 | dataR = np.ascontiguousarray(dataR, dtype=np.float32) 179 | 180 | if self.training: 181 | w, h = left_img.size 182 | tw, th = 512, 256 183 | x1 = random.randint(0, w - tw) 184 | y1 = random.randint(0, h - th) 185 | 186 | left_img = left_img.crop((x1, y1, x1+tw, y1+th)) 187 | right_img = right_img.crop((x1, y1, x1+tw, y1+th)) 188 | dataL = dataL[y1:y1+th, x1:x1+tw] 189 | dataR = dataR[y1:y1+th, x1:x1+tw] 190 | 191 | if self.color_transform: 192 | left_img, right_img = random_transform(left_img, right_img) 193 | 194 | left_img = self.img_transorm(left_img) 195 | right_img = self.img_transorm(right_img) 196 | 197 | return left_img, right_img, dataL, dataR 198 | 199 | else: 200 | w, h = left_img.size 201 | left_img = left_img.crop((w-960, h-544, w, h)) 202 | right_img = right_img.crop((w-960, h-544, w, h)) 203 | 204 | left_img = self.img_transorm(left_img) 205 | right_img = self.img_transorm(right_img) 206 | 207 | dataL = Image.fromarray(dataL).crop((w-960, h-544, w, h)) 208 | dataL = np.ascontiguousarray(dataL) 209 | dataR = Image.fromarray(dataR).crop((w-960, h-544, w, h)) 210 | dataR = np.ascontiguousarray(dataR) 211 | 212 | return left_img, right_img, dataL, dataR 213 | 214 | def __len__(self): 215 | return len(self.left) 216 | 217 | 218 
| 219 | 220 | 221 | -------------------------------------------------------------------------------- /networks/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 7 | """3x3 convolution with padding""" 8 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 9 | padding=dilation, groups=groups, bias=False, dilation=dilation) 10 | 11 | 12 | def conv1x1(in_planes, out_planes, stride=1): 13 | """1x1 convolution""" 14 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 15 | 16 | 17 | class BasicBlock_Res(nn.Module): 18 | expansion = 1 19 | 20 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 21 | base_width=64, dilation=1, norm_layer=None, use_relu=True): 22 | super(BasicBlock_Res, self).__init__() 23 | if norm_layer is None: 24 | norm_layer = nn.BatchNorm2d 25 | if groups != 1 or base_width != 64: 26 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 27 | if dilation > 1: 28 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 29 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 30 | self.conv1 = conv3x3(inplanes, planes, stride) 31 | self.bn1 = norm_layer(planes) 32 | self.relu = nn.ReLU(inplace=True) 33 | self.conv2 = conv3x3(planes, planes) 34 | self.bn2 = norm_layer(planes) 35 | self.downsample = downsample 36 | self.stride = stride 37 | 38 | self.use_relu = use_relu 39 | 40 | def forward(self, x): 41 | identity = x 42 | 43 | out = self.conv1(x) 44 | out = self.bn1(out) 45 | out = self.relu(out) 46 | 47 | out = self.conv2(out) 48 | out = self.bn2(out) 49 | 50 | if self.downsample is not None: 51 | identity = self.downsample(x) 52 | 53 | out += identity 54 | 55 | if self.use_relu: 56 | out = self.relu(out) 57 | return out 58 | 59 | 60 | class Bottleneck(nn.Module): 61 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 62 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 63 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 64 | # This variant is also known as ResNet V1.5 and improves accuracy according to 65 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 
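    # Compared with the stock torchvision Bottleneck, this copy adds a use_relu flag:
    # ResNet._make_layer below passes use_relu=False to the last block of each stage,
    # so a stage can return its features before the final ReLU is applied.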
66 | 67 | expansion = 4 68 | 69 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 70 | base_width=64, dilation=1, norm_layer=None, use_relu=True): 71 | super(Bottleneck, self).__init__() 72 | if norm_layer is None: 73 | norm_layer = nn.BatchNorm2d 74 | width = int(planes * (base_width / 64.)) * groups 75 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 76 | self.conv1 = conv1x1(inplanes, width) 77 | self.bn1 = norm_layer(width) 78 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 79 | self.bn2 = norm_layer(width) 80 | self.conv3 = conv1x1(width, planes * self.expansion) 81 | self.bn3 = norm_layer(planes * self.expansion) 82 | self.relu = nn.ReLU(inplace=True) 83 | self.downsample = downsample 84 | self.stride = stride 85 | 86 | self.use_relu = use_relu 87 | 88 | def forward(self, x): 89 | identity = x 90 | 91 | out = self.conv1(x) 92 | out = self.bn1(out) 93 | out = self.relu(out) 94 | 95 | out = self.conv2(out) 96 | out = self.bn2(out) 97 | out = self.relu(out) 98 | 99 | out = self.conv3(out) 100 | out = self.bn3(out) 101 | 102 | if self.downsample is not None: 103 | identity = self.downsample(x) 104 | 105 | out += identity 106 | 107 | if self.use_relu: 108 | out = self.relu(out) 109 | 110 | return out 111 | 112 | 113 | class ResNet(nn.Module): 114 | 115 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 116 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 117 | norm_layer=None): 118 | super(ResNet, self).__init__() 119 | if norm_layer is None: 120 | norm_layer = nn.BatchNorm2d 121 | self._norm_layer = norm_layer 122 | 123 | self.inplanes = 64 124 | self.dilation = 1 125 | if replace_stride_with_dilation is None: 126 | # each element in the tuple indicates if we should replace 127 | # the 2x2 stride with a dilated convolution instead 128 | replace_stride_with_dilation = [False, False, False] 129 | if len(replace_stride_with_dilation) != 3: 130 | raise ValueError("replace_stride_with_dilation should be None " 131 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 132 | self.groups = groups 133 | self.base_width = width_per_group 134 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 135 | bias=False) 136 | self.bn1 = norm_layer(self.inplanes) 137 | self.relu = nn.ReLU(inplace=True) 138 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 139 | self.layer1 = self._make_layer(block, 64, layers[0]) 140 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 141 | dilate=replace_stride_with_dilation[0]) 142 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 143 | dilate=replace_stride_with_dilation[1]) 144 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 145 | dilate=replace_stride_with_dilation[2]) 146 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 147 | 148 | # if DenseCl, comment this line 149 | self.fc = nn.Linear(512 * block.expansion, num_classes) 150 | 151 | for m in self.modules(): 152 | if isinstance(m, nn.Conv2d): 153 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 154 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 155 | nn.init.constant_(m.weight, 1) 156 | nn.init.constant_(m.bias, 0) 157 | 158 | # Zero-initialize the last BN in each residual branch, 159 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
160 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 161 | if zero_init_residual: 162 | for m in self.modules(): 163 | if isinstance(m, Bottleneck): 164 | nn.init.constant_(m.bn3.weight, 0) 165 | elif isinstance(m, BasicBlock_Res): 166 | nn.init.constant_(m.bn2.weight, 0) 167 | 168 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 169 | norm_layer = self._norm_layer 170 | downsample = None 171 | previous_dilation = self.dilation 172 | if dilate: 173 | self.dilation *= stride 174 | stride = 1 175 | 176 | # stride = 1 177 | 178 | if stride != 1 or self.inplanes != planes * block.expansion: 179 | downsample = nn.Sequential( 180 | conv1x1(self.inplanes, planes * block.expansion, stride), 181 | norm_layer(planes * block.expansion), 182 | ) 183 | 184 | layers = [] 185 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 186 | self.base_width, previous_dilation, norm_layer)) 187 | self.inplanes = planes * block.expansion 188 | for i in range(1, blocks): 189 | if i == blocks - 1: 190 | layers.append(block(self.inplanes, planes, groups=self.groups, 191 | base_width=self.base_width, dilation=self.dilation, 192 | norm_layer=norm_layer, use_relu=False)) 193 | else: 194 | layers.append(block(self.inplanes, planes, groups=self.groups, 195 | base_width=self.base_width, dilation=self.dilation, 196 | norm_layer=norm_layer, use_relu=True)) 197 | 198 | return nn.Sequential(*layers) 199 | 200 | def _forward_impl(self, x): 201 | # See note [TorchScript super()] 202 | x = self.conv1(x) 203 | x = self.bn1(x) 204 | x = self.relu(x) 205 | x = self.maxpool(x) 206 | x = self.layer1(x) 207 | 208 | # x = self.relu(x) 209 | 210 | # x = self.layer2(x) 211 | 212 | # x = self.relu(x) 213 | 214 | # x = self.layer3(x) 215 | # x = self.layer4(x) 216 | # 217 | # x = self.avgpool(x) 218 | # x = torch.flatten(x, 1) 219 | 220 | # x = self.fc(x) 221 | 222 | return x 223 | 224 | def forward(self, x): 225 | return self._forward_impl(x) 226 | 227 | -------------------------------------------------------------------------------- /networks/U_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | 6 | class conv_block(nn.Module): 7 | def __init__(self, ch_in, ch_out): 8 | super(conv_block, self).__init__() 9 | self.conv = nn.Sequential( 10 | nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=False), 11 | nn.BatchNorm2d(ch_out), 12 | nn.ReLU(inplace=True), 13 | nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=False), 14 | nn.BatchNorm2d(ch_out), 15 | nn.ReLU(inplace=True) 16 | ) 17 | 18 | def forward(self, x): 19 | x = self.conv(x) 20 | return x 21 | 22 | 23 | class up_conv(nn.Module): 24 | def __init__(self, ch_in, ch_out): 25 | super(up_conv, self).__init__() 26 | self.up = nn.Sequential( 27 | nn.Upsample(scale_factor=2), 28 | nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True), 29 | nn.BatchNorm2d(ch_out), 30 | nn.ReLU(inplace=True) 31 | ) 32 | 33 | def forward(self, x): 34 | x = self.up(x) 35 | return x 36 | 37 | 38 | class U_Net(nn.Module): 39 | def __init__(self, img_ch=3, output_ch=1): 40 | super(U_Net, self).__init__() 41 | 42 | self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2) 43 | 44 | self.Conv1 = conv_block(ch_in=img_ch, ch_out=32) 45 | self.Conv2 = conv_block(ch_in=32, ch_out=64) 46 | self.Conv3 = conv_block(ch_in=64, ch_out=128) 47 | self.Conv4 = conv_block(ch_in=128, ch_out=256) 48 | # 
self.Conv5 = conv_block(ch_in=256, ch_out=512) 49 | self.Conv5 = conv_block(ch_in=256, ch_out=256) 50 | 51 | # self.Up5 = up_conv(ch_in=512, ch_out=256) 52 | self.Up5 = up_conv(ch_in=256, ch_out=256) 53 | self.Up_conv5 = conv_block(ch_in=512, ch_out=256) 54 | 55 | self.Up4 = up_conv(ch_in=256, ch_out=128) 56 | self.Up_conv4 = conv_block(ch_in=256, ch_out=128) 57 | 58 | self.Up3 = up_conv(ch_in=128, ch_out=64) 59 | self.Up_conv3 = conv_block(ch_in=128, ch_out=64) 60 | 61 | self.Up2 = up_conv(ch_in=64, ch_out=32) 62 | self.Up_conv2 = conv_block(ch_in=64, ch_out=32) 63 | 64 | self.Conv_1x1 = nn.Conv2d(32, output_ch, kernel_size=1, stride=1, padding=0) 65 | 66 | for m in self.modules(): 67 | if isinstance(m, nn.Conv2d): 68 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 69 | m.weight.data.normal_(0, math.sqrt(2. / n)) 70 | elif isinstance(m, nn.BatchNorm2d): 71 | m.weight.data.fill_(1) 72 | m.bias.data.zero_() 73 | elif isinstance(m, nn.Linear): 74 | m.bias.data.zero_() 75 | 76 | def forward(self, x): 77 | x1 = self.Conv1(x) 78 | 79 | x2 = self.Maxpool(x1) 80 | x2 = self.Conv2(x2) 81 | 82 | x3 = self.Maxpool(x2) 83 | x3 = self.Conv3(x3) 84 | 85 | x4 = self.Maxpool(x3) 86 | x4 = self.Conv4(x4) 87 | 88 | x5 = self.Maxpool(x4) 89 | x5 = self.Conv5(x5) 90 | 91 | d5 = self.Up5(x5) 92 | d5 = torch.cat((x4, d5), dim=1) 93 | d5 = self.Up_conv5(d5) 94 | 95 | d4 = self.Up4(d5) 96 | d4 = torch.cat((x3, d4), dim=1) 97 | d4 = self.Up_conv4(d4) 98 | 99 | d3 = self.Up3(d4) 100 | d3 = torch.cat((x2, d3), dim=1) 101 | d3 = self.Up_conv3(d3) 102 | 103 | d2 = self.Up2(d3) 104 | d2 = torch.cat((x1, d2), dim=1) 105 | d2 = self.Up_conv2(d2) 106 | 107 | d1 = self.Conv_1x1(d2) 108 | 109 | return d1 110 | 111 | 112 | class U_Net_v2(nn.Module): 113 | def __init__(self, img_ch=3, output_ch=1): 114 | super(U_Net_v2, self).__init__() 115 | 116 | self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2) 117 | 118 | self.Conv1 = conv_block(ch_in=img_ch, ch_out=32) 119 | self.Conv2 = conv_block(ch_in=32, ch_out=64) 120 | self.Conv3 = conv_block(ch_in=64, ch_out=128) 121 | self.Conv4 = conv_block(ch_in=128, ch_out=256) 122 | 123 | self.Up4 = up_conv(ch_in=256, ch_out=128) 124 | self.Up_conv4 = conv_block(ch_in=256, ch_out=128) 125 | 126 | self.Up3 = up_conv(ch_in=128, ch_out=64) 127 | self.Up_conv3 = conv_block(ch_in=128, ch_out=64) 128 | 129 | self.Up2 = up_conv(ch_in=64, ch_out=32) 130 | self.Up_conv2 = conv_block(ch_in=64, ch_out=32) 131 | 132 | self.Conv_1x1 = nn.Conv2d(32, output_ch, kernel_size=1, stride=1, padding=0) 133 | 134 | for m in self.modules(): 135 | if isinstance(m, nn.Conv2d): 136 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 137 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 138 | elif isinstance(m, nn.BatchNorm2d): 139 | m.weight.data.fill_(1) 140 | m.bias.data.zero_() 141 | elif isinstance(m, nn.Linear): 142 | m.bias.data.zero_() 143 | 144 | def forward(self, x): 145 | x1 = self.Conv1(x) 146 | 147 | x2 = self.Maxpool(x1) 148 | x2 = self.Conv2(x2) 149 | 150 | x3 = self.Maxpool(x2) 151 | x3 = self.Conv3(x3) 152 | 153 | x4 = self.Maxpool(x3) 154 | x4 = self.Conv4(x4) 155 | 156 | d4 = self.Up4(x4) 157 | d4 = torch.cat((x3, d4), dim=1) 158 | d4 = self.Up_conv4(d4) 159 | 160 | d3 = self.Up3(d4) 161 | d3 = torch.cat((x2, d3), dim=1) 162 | d3 = self.Up_conv3(d3) 163 | 164 | d2 = self.Up2(d3) 165 | d2 = torch.cat((x1, d2), dim=1) 166 | d2 = self.Up_conv2(d2) 167 | 168 | d1 = self.Conv_1x1(d2) 169 | 170 | return d1 171 | 172 | 173 | class U_Net_v3(nn.Module): 174 | def __init__(self, img_ch=3, output_ch=1): 175 | super(U_Net_v3, self).__init__() 176 | 177 | self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2) 178 | 179 | self.Conv0 = conv_block(ch_in=img_ch, ch_out=64) 180 | self.Conv1 = conv_block(ch_in=64, ch_out=128) 181 | self.Conv2 = conv_block(ch_in=128, ch_out=256) 182 | 183 | self.Up5 = up_conv(ch_in=256, ch_out=128) 184 | self.Up_conv5 = conv_block(ch_in=256, ch_out=128) 185 | 186 | self.Up4 = up_conv(ch_in=128, ch_out=64) 187 | self.Up_conv4 = conv_block(ch_in=128, ch_out=64) 188 | 189 | self.Up3 = up_conv(ch_in=64, ch_out=32) 190 | self.Up_conv3 = conv_block(ch_in=32, ch_out=32) 191 | 192 | self.Up2 = up_conv(ch_in=32, ch_out=32) 193 | self.Up_conv2 = conv_block(ch_in=32, ch_out=32) 194 | 195 | self.Conv_1x1 = nn.Conv2d(32, output_ch, kernel_size=1, stride=1, padding=0) 196 | 197 | for m in self.modules(): 198 | if isinstance(m, nn.Conv2d): 199 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 200 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 201 | elif isinstance(m, nn.BatchNorm2d): 202 | m.weight.data.fill_(1) 203 | m.bias.data.zero_() 204 | elif isinstance(m, nn.Linear): 205 | m.bias.data.zero_() 206 | 207 | def forward(self, x): 208 | x0 = self.Conv0(x) # 64 channels 209 | 210 | x1 = self.Conv1(x0) # 128 channels 211 | x1 = self.Maxpool(x1) # 1/8 resolution 212 | 213 | x2 = self.Conv2(x1) # 256 channels 214 | x2 = self.Maxpool(x2) # 1/16 resolution 215 | 216 | d4 = self.Up5(x2) # 1/8 resolution 217 | d4 = torch.cat((x1, d4), dim=1) 218 | d4 = self.Up_conv5(d4) # 128 channels 219 | 220 | d3 = self.Up4(d4) # 1/4 resolution 221 | d3 = torch.cat((x0, d3), dim=1) 222 | d3 = self.Up_conv4(d3) # 64 channels 223 | 224 | d2 = self.Up3(d3) # 1/2 resolution 225 | d2 = self.Up_conv3(d2) # 32 channels 226 | 227 | d1 = self.Up2(d2) # 1/2 resolution 228 | d1 = self.Up_conv2(d1) # 32 channels 229 | 230 | d0 = self.Conv_1x1(d1) 231 | 232 | return d0 233 | 234 | 235 | class U_Net_v4(nn.Module): 236 | def __init__(self, img_ch, output_ch): 237 | super(U_Net_v4, self).__init__() 238 | 239 | self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2) 240 | 241 | self.Conv1 = conv_block(ch_in=img_ch, ch_out=32) 242 | self.Conv2 = conv_block(ch_in=32, ch_out=64) 243 | self.Conv3 = conv_block(ch_in=64, ch_out=128) 244 | 245 | self.Conv4 = conv_block(ch_in=128, ch_out=128) 246 | 247 | self.Up4 = conv_block(ch_in=128, ch_out=128) 248 | self.Up_conv4 = up_conv(ch_in=256, ch_out=64) 249 | 250 | self.Up3 = conv_block(ch_in=64, ch_out=64) 251 | self.Up_conv3 = up_conv(ch_in=128, ch_out=32) 252 | 253 | self.last_conv = nn.Conv2d(64, output_ch, 1, 1, 0, 1) 254 | 255 | for m in self.modules(): 256 | if isinstance(m, nn.Conv2d): 257 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 258 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 259 | # nn.init.kaiming_normal_(m, mode='fan_in', nonlinearity='relu') 260 | elif isinstance(m, nn.BatchNorm2d): 261 | m.weight.data.fill_(1) 262 | m.bias.data.zero_() 263 | elif isinstance(m, nn.Linear): 264 | m.bias.data.zero_() 265 | 266 | def forward(self, x): 267 | x1 = self.Conv1(x) # 32, 1/4 268 | 269 | x2 = self.Maxpool(x1) 270 | x2 = self.Conv2(x2) # 64, 1/8 271 | 272 | x3 = self.Maxpool(x2) 273 | x3 = self.Conv3(x3) # 128, 1/16 274 | 275 | x4 = self.Conv4(x3) # 128, 1/16 276 | 277 | d4 = self.Up4(x4) # 128, 1/16 278 | d4 = torch.cat((x3, d4), dim=1) 279 | d4 = self.Up_conv4(d4) # 64, 1/8 280 | 281 | d3 = self.Up3(d4) # 64, 1/8 282 | d3 = torch.cat((x2, d3), dim=1) 283 | d3 = self.Up_conv3(d3) # 32, 1/4 284 | 285 | d2 = torch.cat((x1, d3), dim=1) 286 | d2 = self.last_conv(d2) 287 | 288 | return d2 289 | 290 | 291 | class LinearProj(nn.Module): 292 | def __init__(self, in_c, out_c): 293 | super(LinearProj, self).__init__() 294 | self.conv = nn.Sequential( 295 | nn.Conv2d(in_c, out_c, 1, 1, 0, 1), 296 | nn.ReLU(inplace=True), 297 | nn.Conv2d(out_c, out_c, 1, 1, 0, 1)) 298 | # self.conv = nn.Conv2d(in_c, out_c, 1, 1, 0, 1) 299 | 300 | def forward(self, x): 301 | x = self.conv(x) 302 | return x 303 | 304 | 305 | if __name__ == '__main__': 306 | a = torch.rand(2, 3, 64, 128).cuda() 307 | net = U_Net_v3(img_ch=3, output_ch=4).cuda() 308 | b = net(a) 309 | print(b.shape) -------------------------------------------------------------------------------- /networks/Aggregator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.utils.data 4 | from torch.autograd import Variable 5 | import torch.nn.functional as F 6 | import math 7 | from networks.submodule import convbn, convbn_3d, DisparityRegression 8 | from networks.stackhourglass import hourglass_gwcnet, hourglass 9 | import matplotlib.pyplot as plt 10 | import loss_functions as lf 11 | 12 | 13 | def build_cost_volume(left_fea, right_fea, max_disp, cost_type): 14 | if cost_type == 'cor': 15 | 16 | left_fea_norm = F.normalize(left_fea, dim=1) 17 | right_fea_norm = F.normalize(right_fea, dim=1) 18 | 19 | cost = torch.zeros(left_fea.size()[0], 1, max_disp // 4, 20 | left_fea.size()[2], left_fea.size()[3]).cuda() 21 | 22 | for i in range(max_disp // 4): 23 | if i > 0: 24 | cost[:, :, i, :, i:] = (torch.sum(left_fea_norm[:, :, :, i:] * right_fea_norm[:, :, :, :-i], 25 | dim=1, keepdim=True) + 1) / 2 26 | else: 27 | cost[:, :, i, :, :] = (torch.sum(left_fea_norm * right_fea_norm, dim=1, keepdim=True) + 1) / 2 28 | 29 | elif cost_type == 'l2': 30 | cost = torch.zeros(left_fea.size()[0], 1, max_disp // 4, 31 | left_fea.size()[2], left_fea.size()[3]).cuda() 32 | 33 | for i in range(max_disp // 4): 34 | if i > 0: 35 | cost[:, :, i, :, i:] = torch.sqrt(torch.sum( 36 | (left_fea[:, :, :, i:] - right_fea[:, :, :, :-i]) ** 2, dim=1, keepdim=True)) 37 | 38 | else: 39 | cost[:, :, i, :, :] = torch.sqrt(torch.sum((left_fea - right_fea) ** 2, dim=1, keepdim=True)) 40 | 41 | elif cost_type == 'cat': 42 | 43 | cost = torch.zeros(left_fea.size()[0], left_fea.size()[1] * 2, max_disp // 4, 44 | left_fea.size()[2], left_fea.size()[3]).cuda() 45 | 46 | for i in range(max_disp // 4): 47 | if i > 0: 48 | cost[:, :left_fea.size()[1], i, :, i:] = left_fea[:, :, :, i:] 49 | cost[:, left_fea.size()[1]:, i, :, i:] = right_fea[:, :, :, :-i] 50 | else: 51 | cost[:, :left_fea.size()[1], i, :, :] = left_fea 52 | cost[:, left_fea.size()[1]:, i, :, :] = right_fea 53 | 54 | elif 
cost_type == 'ncat': 55 | 56 | left_fea = F.normalize(left_fea, dim=1) 57 | right_fea = F.normalize(right_fea, dim=1) 58 | 59 | cost = torch.zeros(left_fea.size()[0], left_fea.size()[1] * 2, max_disp // 4, 60 | left_fea.size()[2], left_fea.size()[3]).cuda() 61 | 62 | for i in range(max_disp // 4): 63 | if i > 0: 64 | cost[:, :left_fea.size()[1], i, :, i:] = left_fea[:, :, :, i:] 65 | cost[:, left_fea.size()[1]:, i, :, i:] = right_fea[:, :, :, :-i] 66 | else: 67 | cost[:, :left_fea.size()[1], i, :, :] = left_fea 68 | cost[:, left_fea.size()[1]:, i, :, :] = right_fea 69 | 70 | cost = cost.contiguous() 71 | 72 | return cost 73 | 74 | 75 | class GwcAggregator(nn.Module): 76 | def __init__(self, maxdisp): 77 | super(GwcAggregator, self).__init__() 78 | self.maxdisp = maxdisp 79 | 80 | self.dres0 = nn.Sequential(convbn_3d(64, 32, 3, 1, 1), 81 | nn.ReLU(inplace=True), 82 | convbn_3d(32, 32, 3, 1, 1), 83 | nn.ReLU(inplace=True)) 84 | self.dres1 = nn.Sequential(convbn_3d(32, 32, 3, 1, 1), 85 | nn.ReLU(inplace=True), 86 | convbn_3d(32, 32, 3, 1, 1), 87 | nn.ReLU(inplace=True)) 88 | 89 | self.hg1 = hourglass_gwcnet(32) 90 | self.hg2 = hourglass_gwcnet(32) 91 | self.hg3 = hourglass_gwcnet(32) 92 | 93 | self.classify1 = nn.Sequential(convbn_3d(32, 32, 3, 1, 1), 94 | nn.ReLU(inplace=True), 95 | nn.Conv3d(32, 1, kernel_size=3, padding=1, stride=1, bias=False)) 96 | self.classify2 = nn.Sequential(convbn_3d(32, 32, 3, 1, 1), 97 | nn.ReLU(inplace=True), 98 | nn.Conv3d(32, 1, kernel_size=3, padding=1, stride=1, bias=False)) 99 | self.classify3 = nn.Sequential(convbn_3d(32, 32, 3, 1, 1), 100 | nn.ReLU(inplace=True), 101 | nn.Conv3d(32, 1, kernel_size=3, padding=1, stride=1, bias=False)) 102 | 103 | for m in self.modules(): 104 | if isinstance(m, nn.Conv2d): 105 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 106 | m.weight.data.normal_(0, math.sqrt(2. / n)) 107 | elif isinstance(m, nn.Conv3d): 108 | n = m.kernel_size[0] * m.kernel_size[1]*m.kernel_size[2] * m.out_channels 109 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 110 | elif isinstance(m, nn.BatchNorm2d): 111 | m.weight.data.fill_(1) 112 | m.bias.data.zero_() 113 | elif isinstance(m, nn.BatchNorm3d): 114 | m.weight.data.fill_(1) 115 | m.bias.data.zero_() 116 | elif isinstance(m, nn.Linear): 117 | m.bias.data.zero_() 118 | 119 | def forward(self, left_fea, right_fea, gt_left, gt_right): 120 | cost = build_cost_volume(left_fea, right_fea, self.maxdisp, cost_type='ncat') 121 | 122 | cost0 = self.dres0(cost) 123 | cost1 = self.dres1(cost0) + cost0 124 | 125 | out1 = self.hg1(cost1) 126 | out2 = self.hg2(out1) 127 | out3 = self.hg3(out2) 128 | 129 | win_s = 5 130 | 131 | if self.training: 132 | cost1 = self.classify1(out1) 133 | cost1 = F.interpolate(cost1, scale_factor=4, mode='trilinear', align_corners=True) 134 | cost1 = torch.squeeze(cost1, 1) 135 | distribute1 = F.softmax(cost1, dim=1) 136 | pred1 = DisparityRegression(self.maxdisp, win_size=win_s)(distribute1) 137 | 138 | cost2 = self.classify2(out2) 139 | cost2 = F.interpolate(cost2, scale_factor=4, mode='trilinear', align_corners=True) 140 | cost2 = torch.squeeze(cost2, 1) 141 | distribute2 = F.softmax(cost2, dim=1) 142 | pred2 = DisparityRegression(self.maxdisp, win_size=win_s)(distribute2) 143 | 144 | cost3 = self.classify3(out3) 145 | cost3 = F.interpolate(cost3, scale_factor=4, mode='trilinear', align_corners=True) 146 | cost3 = torch.squeeze(cost3, 1) 147 | distribute3 = F.softmax(cost3, dim=1) 148 | pred3 = DisparityRegression(self.maxdisp, win_size=win_s)(distribute3) 149 | 150 | if self.training: 151 | mask = (gt_left < self.maxdisp) & (gt_left > 0) 152 | loss1 = 0.5 * F.smooth_l1_loss(pred1[mask], gt_left[mask]) + \ 153 | 0.7 * F.smooth_l1_loss(pred2[mask], gt_left[mask]) + \ 154 | F.smooth_l1_loss(pred3[mask], gt_left[mask]) 155 | 156 | gt_distribute = lf.disp2distribute(gt_left, self.maxdisp, b=2) 157 | loss2 = 0.5 * lf.CEloss(gt_left, self.maxdisp, gt_distribute, distribute1) + \ 158 | 0.7 * lf.CEloss(gt_left, self.maxdisp, gt_distribute, distribute2) + \ 159 | lf.CEloss(gt_left, self.maxdisp, gt_distribute, distribute3) 160 | 161 | loss3 = lf.FeatureSimilarityLoss(self.maxdisp)(left_fea, right_fea, gt_left, gt_right) 162 | 163 | return loss1, loss2, loss3 164 | 165 | else: 166 | return pred3 167 | 168 | 169 | class PSMAggregator(nn.Module): 170 | def __init__(self, maxdisp, udc): 171 | super(PSMAggregator, self).__init__() 172 | self.maxdisp = maxdisp 173 | self.udc = udc 174 | 175 | self.dres0 = nn.Sequential(convbn_3d(1, 32, 3, 1, 1), 176 | nn.ReLU(inplace=True), 177 | convbn_3d(32, 32, 3, 1, 1), 178 | nn.ReLU(inplace=True)) 179 | self.dres1 = nn.Sequential(convbn_3d(32, 32, 3, 1, 1), 180 | nn.ReLU(inplace=True), 181 | convbn_3d(32, 32, 3, 1, 1), 182 | nn.ReLU(inplace=True)) 183 | 184 | self.hg1 = hourglass(32) 185 | self.hg2 = hourglass(32) 186 | self.hg3 = hourglass(32) 187 | 188 | self.classify1 = nn.Sequential(convbn_3d(32, 32, 3, 1, 1), 189 | nn.ReLU(inplace=True), 190 | nn.Conv3d(32, 1, kernel_size=3, padding=1, stride=1, bias=False)) 191 | self.classify2 = nn.Sequential(convbn_3d(32, 32, 3, 1, 1), 192 | nn.ReLU(inplace=True), 193 | nn.Conv3d(32, 1, kernel_size=3, padding=1, stride=1, bias=False)) 194 | self.classify3 = nn.Sequential(convbn_3d(32, 32, 3, 1, 1), 195 | nn.ReLU(inplace=True), 196 | nn.Conv3d(32, 1, kernel_size=3, padding=1, stride=1, bias=False)) 197 | 198 | for m in self.modules(): 199 | if isinstance(m, nn.Conv2d): 200 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 201 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 202 | elif isinstance(m, nn.Conv3d): 203 | n = m.kernel_size[0] * m.kernel_size[1]*m.kernel_size[2] * m.out_channels 204 | m.weight.data.normal_(0, math.sqrt(2. / n)) 205 | elif isinstance(m, nn.BatchNorm2d): 206 | m.weight.data.fill_(1) 207 | m.bias.data.zero_() 208 | elif isinstance(m, nn.BatchNorm3d): 209 | m.weight.data.fill_(1) 210 | m.bias.data.zero_() 211 | elif isinstance(m, nn.Linear): 212 | m.bias.data.zero_() 213 | 214 | def forward(self, left_fea, right_fea, gt_left, training): 215 | cost = build_cost_volume(left_fea, right_fea, self.maxdisp, cost_type='cor') 216 | 217 | cost0 = self.dres0(cost) 218 | cost1 = self.dres1(cost0) + cost0 219 | 220 | out1, pre1, post1 = self.hg1(cost1, None, None) 221 | out1 = out1+cost0 222 | 223 | out2, pre2, post2 = self.hg2(out1, pre1, post1) 224 | out2 = out2+cost0 225 | 226 | out3, pre3, post3 = self.hg3(out2, pre1, post2) 227 | out3 = out3+cost0 228 | 229 | cost1 = self.classify1(out1) 230 | cost2 = self.classify2(out2) + cost1 231 | cost3 = self.classify3(out3) + cost2 232 | 233 | if self.udc: 234 | win_s = 5 235 | else: 236 | win_s = 0 237 | 238 | if self.training: 239 | cost1 = F.interpolate(cost1, scale_factor=4, mode='trilinear', align_corners=True) 240 | cost1 = torch.squeeze(cost1, 1) 241 | distribute1 = F.softmax(cost1, dim=1) 242 | pred1 = DisparityRegression(self.maxdisp, win_size=win_s)(distribute1) 243 | 244 | cost2 = F.interpolate(cost2, scale_factor=4, mode='trilinear', align_corners=True) 245 | cost2 = torch.squeeze(cost2, 1) 246 | distribute2 = F.softmax(cost2, dim=1) 247 | pred2 = DisparityRegression(self.maxdisp, win_size=win_s)(distribute2) 248 | 249 | cost3 = F.interpolate(cost3, scale_factor=4, mode='trilinear', align_corners=True) 250 | cost3 = torch.squeeze(cost3, 1) 251 | distribute3 = F.softmax(cost3, dim=1) 252 | pred3 = DisparityRegression(self.maxdisp, win_size=win_s)(distribute3) 253 | 254 | if self.training: 255 | mask = (gt_left < self.maxdisp) & (gt_left > 0) 256 | 257 | loss1 = 0.5 * F.smooth_l1_loss(pred1[mask], gt_left[mask]) + \ 258 | 0.7 * F.smooth_l1_loss(pred2[mask], gt_left[mask]) + \ 259 | F.smooth_l1_loss(pred3[mask], gt_left[mask]) 260 | 261 | gt_distribute = lf.disp2distribute(gt_left, self.maxdisp, b=2) 262 | loss2 = 0.5 * lf.CEloss(gt_left, self.maxdisp, gt_distribute, distribute1) + \ 263 | 0.7 * lf.CEloss(gt_left, self.maxdisp, gt_distribute, distribute2) + \ 264 | lf.CEloss(gt_left, self.maxdisp, gt_distribute, distribute3) 265 | return loss1, loss2 266 | 267 | else: 268 | if training: 269 | mask = (gt_left < self.maxdisp) & (gt_left > 0) 270 | loss1 = F.smooth_l1_loss(pred3[mask], gt_left[mask]) 271 | # loss2 = loss1 272 | gt_distribute = lf.disp2distribute(gt_left, self.maxdisp, b=2) 273 | loss2 = lf.CEloss(gt_left, self.maxdisp, gt_distribute, distribute3) 274 | return loss1, loss2 275 | 276 | else: 277 | return pred3 278 | --------------------------------------------------------------------------------
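The following inference sketch is not part of the repository; it is assembled from retrain_CostAggregation.py above, and the released test scripts (test_kitti.py, test_middlebury.py, test_eth3d.py) may wire things up differently. It assumes a checkpoint produced by retrain_CostAggregation.py (key 'fa_net' holding the adaptor weights and 'net' holding the aggregator weights) and a CUDA device, since build_cost_volume allocates its cost volume with .cuda(). The checkpoint path and the input tensors are placeholders.

```python
# Minimal sketch: module names and the checkpoint layout follow retrain_CostAggregation.py;
# paths and inputs below are placeholders, not values shipped with the repository.
import torch
import torch.nn as nn

import networks.feature_extraction as FE
import networks.U_net as un
import networks.Aggregator as Agg

max_disp = 192
ckpt_path = 'trained_models/checkpoint_final_10epoch.tar'  # example path, adjust to your run

# Build the pipeline exactly as the training scripts do: frozen VGG feature, U-Net adaptor,
# PSM-style cost aggregator, all wrapped in DataParallel so the saved state dicts match.
fe_model = nn.DataParallel(FE.VGG_Feature(fixed_param=True).cuda()).eval()
adaptor = nn.DataParallel(un.U_Net_v4(img_ch=256, output_ch=64).cuda()).eval()
agg_model = nn.DataParallel(Agg.PSMAggregator(max_disp, udc=True).cuda()).eval()

ckpt = torch.load(ckpt_path)
adaptor.load_state_dict(ckpt['fa_net'])   # feature adaptor weights
agg_model.load_state_dict(ckpt['net'])    # cost aggregation weights

# Placeholder inputs: ImageNet-normalized RGB pairs whose height/width are multiples of 16
# (the SceneFlow loader above crops evaluation images to 544 x 960).
imgL = torch.rand(1, 3, 544, 960).cuda()
imgR = torch.rand(1, 3, 544, 960).cuda()

with torch.no_grad():
    left_fea = adaptor(fe_model(imgL))    # grafted VGG feature -> adapted 64-channel feature
    right_fea = adaptor(fe_model(imgR))
    # With the module in eval mode and training=False, PSMAggregator returns the final
    # full-resolution disparity map (pred3); the ground-truth argument is unused here.
    disp = agg_model(left_fea, right_fea, None, False)

print(disp.shape)  # expected: torch.Size([1, 544, 960])
```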