├── figure
│   └── figure.png
├── LICENSE
├── dataloader
│   ├── readpfm.py
│   ├── KITTI2012loader.py
│   ├── ETH3D_loader.py
│   ├── vKITTI_loader.py
│   ├── middlebury_loader.py
│   ├── KITTIloader.py
│   └── sceneflow_loader.py
├── README.md
├── networks
│   ├── feature_extraction.py
│   ├── stackhourglass.py
│   ├── vgg.py
│   ├── submodule.py
│   ├── resnet.py
│   ├── U_net.py
│   └── Aggregator.py
├── test_eth3d.py
├── test_kitti.py
├── test_middlebury.py
├── train_baseline.py
├── train_adaptor.py
├── retrain_CostAggregation.py
└── loss_functions.py
/figure/figure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SpadeLiu/Graft-PSMNet/HEAD/figure/figure.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 qqwweee
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/dataloader/readpfm.py:
--------------------------------------------------------------------------------
1 | from __future__ import unicode_literals
2 | import re
3 | import numpy as np
4 |
5 |
6 | def readPFM(file):
7 | file = open(file, 'rb')
8 |
9 | color = None
10 | width = None
11 | height = None
12 | scale = None
13 | endian = None
14 |
15 | header = file.readline().rstrip()
16 | header = header.decode('utf-8')
17 | if header == 'PF':
18 | color = True
19 | elif header == 'Pf':
20 | color = False
21 | else:
22 | raise Exception('Not a PFM file.')
23 |
 24 |     dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8'))
25 | if dim_match:
26 | width, height = map(int, dim_match.groups())
27 | else:
28 | raise Exception('Malformed PFM header.')
29 |
30 | scale = float(file.readline().rstrip().decode('utf-8'))
31 | if scale < 0:
32 | endian = '<'
33 | scale = -scale
34 | else:
35 | endian = '>'
36 |
37 | data = np.fromfile(file, endian + 'f')
38 | shape = (height, width, 3) if color else (height, width)
39 |
40 | data = np.reshape(data, shape)
41 | data = np.flipud(data)
42 |
43 | return data, scale
44 |
45 |
46 | if __name__ == '__main__':
47 | img_path = \
48 | '/media/data/dataset/SceneFlow/driving_frames_cleanpass/15mm_focallength/scene_backwards/fast/left/0100.png'
49 | disp_path = img_path.replace('driving_frames_cleanpass', 'driving_disparity').replace('png', 'pfm')
50 |
51 | data, scale = readPFM(disp_path)
 52 |     dataL = np.ascontiguousarray(data, dtype=np.float32)
53 |
54 | import matplotlib.pyplot as plt
55 | plt.imshow(dataL)
56 | plt.show()
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ### GraftNet: Towards Domain Generalized Stereo Matching with a Broad-Spectrum and Task-Oriented Feature
2 |
3 |
4 |
5 | #### Dependencies:
6 | - Python 3.6
7 | - PyTorch 1.7.0
8 | - torchvision 0.3.0
9 | - [VGG trained on ImageNet](https://download.pytorch.org/models/vgg16-397923af.pth)
10 |
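Note: `VGG_Feature` in `networks/feature_extraction.py` loads the downloaded VGG weights from `networks/vgg16-397923af.pth`, so place the file there (or adjust the path) before running steps 2-3 or the evaluation scripts.
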
11 | #### Datasets:
12 | - [SceneFlow](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html)
13 | - [KITTI stereo 2015](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=stereo)
14 | - [KITTI stereo 2012](http://www.cvlibs.net/datasets/kitti/eval_stereo_flow.php?benchmark=stereo)
15 | - [Middlebury v3](https://vision.middlebury.edu/stereo/submit3/)
16 | - [ETH3D](https://www.eth3d.net/datasets#low-res-two-view)
17 |
18 | #### Training Steps:
19 | ##### 1. Train A Basic Stereo Matching Network:
20 | ```bash
21 | python train_baseline.py --data_path (your SceneFlow data folder)
22 | ```
23 | ##### 2. Graft VGG's Feature and Train the Feature Adaptor:
24 | ```bash
25 | python train_adaptor.py --data_path (your SceneFlow data folder)
26 | ```
27 | ##### 3. Retrain the Cost Aggregation Module:
28 | ```bash
29 | python retrain_CostAggregation.py --data_path (your SceneFlow data folder)
30 | ```
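Note: with the default arguments, step 2 loads the baseline checkpoint saved by step 1 (`trained_models/checkpoint_baseline_8epoch.tar`), and step 3 loads the adaptor checkpoint saved by step 2 (`trained_models/checkpoint_adaptor_1epoch.tar`). Pass `--load_path` to either script to use a different checkpoint.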
31 |
32 | #### Evaluation:
33 | ##### Evaluate on KITTI:
34 | ```bash
35 | python test_kitti.py --data_path (your KITTI training data folder) --load_path (the path of the final model)
36 | ```
37 | ##### Evaluate on Middlebury-H:
38 | ```bash
39 | python test_middlebury.py --data_path (your Middlebury training data folder) --load_path (the path of the final model)
40 | ```
41 | ##### Evaluate on ETH3D:
42 | ```bash
43 | python test_eth3d.py --data_path (your ETH3D data folder) --load_path (the path of the final model)
44 | ```
45 |
46 | #### Pretrained Models:
47 | [Google Drive](https://drive.google.com/drive/folders/1Ud9-HpHSXE5qMRQ17Fs8BNLfyE2VW03U?usp=sharing)
48 |
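#### Inference Example:
The snippet below is a minimal single-pair inference sketch assembled from `test_kitti.py` and `test_middlebury.py`; the image and checkpoint paths are placeholders, and the VGG weights must be in place as described under Dependencies.
```python
# Minimal sketch: run the final model on one rectified pair (paths are placeholders).
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image

import networks.feature_extraction as FE
import networks.U_net as un
import networks.Aggregator as Agg

max_disp = 192
fe_model = nn.DataParallel(FE.VGG_Feature(fixed_param=True).cuda()).eval()
adaptor = nn.DataParallel(un.U_Net_v4(img_ch=256, output_ch=64).cuda()).eval()
agg_model = nn.DataParallel(Agg.PSMAggregator(max_disp, udc=True).cuda()).eval()

ckpt = torch.load('trained_models/checkpoint_final_10epoch.tar')
adaptor.load_state_dict(ckpt['fa_net'])
agg_model.load_state_dict(ckpt['net'])

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

limg = Image.open('left.png').convert('RGB')
rimg = Image.open('right.png').convert('RGB')
w, h = limg.size
wi, hi = (w // 16 + 1) * 16, (h // 16 + 1) * 16   # pad up to a multiple of 16
limg = limg.crop((w - wi, h - hi, w, h))          # PIL zero-fills the extra border
rimg = rimg.crop((w - wi, h - hi, w, h))

limg_t = transform(limg).unsqueeze(0).cuda()
rimg_t = transform(rimg).unsqueeze(0).cuda()

with torch.no_grad():
    left_fea = adaptor(fe_model(limg_t))
    right_fea = adaptor(fe_model(rimg_t))
    # test_middlebury.py passes the left image tensor as the third argument when no ground truth is available
    pred = agg_model(left_fea, right_fea, limg_t, training=False)
    pred = pred[:, hi - h:, wi - w:]              # remove the padded border

disparity = pred.squeeze().cpu().numpy()
```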
--------------------------------------------------------------------------------
/dataloader/KITTI2012loader.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | import torchvision.transforms as transforms
3 | import os
4 | from PIL import Image
5 | import random
6 | import numpy as np
7 |
8 |
9 | IMG_EXTENSIONS= [
10 | '.jpg', '.JPG', '.jpeg', '.JPEG',
11 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP'
12 | ]
13 |
14 |
15 | def is_image_file(filename):
16 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
17 |
18 |
19 | def kt2012_loader(filepath):
20 |
21 | left_path = os.path.join(filepath, 'colored_0')
22 | right_path = os.path.join(filepath, 'colored_1')
23 | displ_path = os.path.join(filepath, 'disp_occ')
24 |
25 | total_name = [name for name in os.listdir(left_path) if name.find('_10') > -1]
26 | train_name = total_name[:160]
27 | val_name = total_name[160:]
28 |
29 | train_left = []
30 | train_right = []
31 | train_displ = []
32 | for name in train_name:
33 | train_left.append(os.path.join(left_path, name))
34 | train_right.append(os.path.join(right_path, name))
35 | train_displ.append(os.path.join(displ_path, name))
36 |
37 | val_left = []
38 | val_right = []
39 | val_displ = []
40 | for name in val_name:
41 | val_left.append(os.path.join(left_path, name))
42 | val_right.append(os.path.join(right_path, name))
43 | val_displ.append(os.path.join(displ_path, name))
44 |
45 | return train_left, train_right, train_displ, val_left, val_right, val_displ
46 |
47 |
48 | def img_loader(path):
49 | return Image.open(path).convert('RGB')
50 |
51 |
52 | def disparity_loader(path):
53 | return Image.open(path)
54 |
55 |
56 | transform = transforms.Compose([
57 | transforms.ToTensor(),
58 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
59 | ])
60 |
61 |
62 | class myDataset(data.Dataset):
63 |
64 | def __init__(self, left, right, left_disp, training, imgloader=img_loader, disploader=disparity_loader):
65 | self.left = left
66 | self.right = right
67 | self.left_disp = left_disp
68 | self.imgloader = imgloader
69 | self.disploader = disploader
70 | self.training = training
71 |
72 | def __getitem__(self, index):
73 | left = self.left[index]
74 | right = self.right[index]
75 | left_disp = self.left_disp[index]
76 |
77 | limg = self.imgloader(left)
78 | rimg = self.imgloader(right)
79 | ldisp = self.disploader(left_disp)
80 |
81 | if self.training:
82 | w, h = limg.size
83 | tw, th = 512, 256
84 |
85 | x1 = random.randint(0, w - tw)
86 | y1 = random.randint(0, h - th)
87 |
88 | limg = limg.crop((x1, y1, x1 + tw, y1 + th))
89 | rimg = rimg.crop((x1, y1, x1 + tw, y1 + th))
90 | ldisp = np.ascontiguousarray(ldisp, dtype=np.float32)/256
91 | ldisp = ldisp[y1:y1 + th, x1:x1 + tw]
92 |
93 | limg = transform(limg)
94 | rimg = transform(rimg)
95 |
96 | return limg, rimg, ldisp
97 |
98 | else:
99 | w, h = limg.size
100 |
101 | limg = limg.crop((w-1232, h-368, w, h))
102 | rimg = rimg.crop((w-1232, h-368, w, h))
103 | ldisp = ldisp.crop((w-1232, h-368, w, h))
104 | ldisp = np.ascontiguousarray(ldisp, dtype=np.float32)/256
105 |
106 | limg = transform(limg)
107 | rimg = transform(rimg)
108 |
109 | return limg, rimg, ldisp
110 |
111 | def __len__(self):
112 | return len(self.left)
--------------------------------------------------------------------------------
/networks/feature_extraction.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.utils.data
4 | from torch.autograd import Variable
5 | import torch.nn.functional as F
6 | from torchvision import models
7 | import math
8 | import numpy as np
9 | import torchvision.transforms as transforms
10 | import PIL
11 | import os
12 | import matplotlib.pyplot as plt
13 | from networks.resnet import ResNet, Bottleneck, BasicBlock_Res
14 | from networks.vgg import vgg16
15 | from collections import OrderedDict
16 |
17 |
18 | class VGG_Feature(nn.Module):
19 | def __init__(self, fixed_param):
20 | super(VGG_Feature, self).__init__()
21 |
22 | self.fe = vgg16(pretrained=False)
23 |
24 | self.fe.load_state_dict(
25 | torch.load('networks/vgg16-397923af.pth'))
26 |
27 | features = self.fe.features
28 |
29 | self.to_feat = nn.Sequential()
30 |
31 | for i in range(15):
32 | self.to_feat.add_module(str(i), features[i])
33 |
34 | if fixed_param:
35 | for p in self.to_feat.parameters():
36 | p.requires_grad = False
37 |
38 | def forward(self, x):
39 | feature = self.to_feat(x)
40 |
41 | # feature = F.interpolate(feature, scale_factor=0.5, mode='bilinear', align_corners=True)
42 |
43 | return feature
44 |
45 |
46 | class VGG_Bn_Feature(nn.Module):
47 | def __init__(self):
48 | super(VGG_Bn_Feature, self).__init__()
49 |
50 | features = models.vgg16_bn(pretrained=True).cuda().eval().features
51 | self.to_feat = nn.Sequential()
52 | # for i in range(8):
53 | # self.to_feat.add_module(str(i), features[i])
54 |
55 | for i in range(15):
56 | self.to_feat.add_module(str(i), features[i])
57 |
58 | for p in self.to_feat.parameters():
59 | p.requires_grad = False
60 |
61 | def forward(self, x):
62 | feature = self.to_feat(x)
63 |
64 | # feature = F.interpolate(feature, scale_factor=0.5, mode='bilinear', align_corners=True)
65 |
66 | return feature
67 |
68 |
69 | class Res18(nn.Module):
70 | def __init__(self):
71 | super(Res18, self).__init__()
72 |
73 | self.fe = ResNet(BasicBlock_Res, [2, 2, 2, 2])
74 |
75 | # self.fe = ResNet(Bottleneck, [3, 4, 6, 3])
76 |
77 | for p in self.fe.parameters():
78 | p.requires_grad = False
79 |
80 | self.fe.load_state_dict(
81 | torch.load('networks/resnet18-5c106cde.pth'))
82 |
83 | def forward(self, x):
84 |
85 | self.fe.eval()
86 |
87 | with torch.no_grad():
88 | feature = self.fe(x)
89 |
90 | return feature
91 |
92 |
93 | class Res50(nn.Module):
94 | def __init__(self):
95 | super(Res50, self).__init__()
96 |
97 | self.fe = ResNet(Bottleneck, [3, 4, 6, 3])
98 |
99 | for p in self.fe.parameters():
100 | p.requires_grad = False
101 |
102 | # self.fe.load_state_dict(
103 | # torch.load('networks/resnet50-19c8e357.pth'))
104 | self.fe.load_state_dict(
105 | torch.load('networks/DenseCL_R50_imagenet.pth'))
106 |
107 | def forward(self, x):
108 |
109 | self.fe.eval()
110 |
111 | with torch.no_grad():
112 | feature = self.fe(x)
113 |
114 | return feature
115 |
116 |
117 | if __name__ == '__main__':
118 | os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
119 | os.environ['CUDA_VISIBLE_DEVICES'] = "2"
120 | from collections import OrderedDict
121 | ckpt = torch.load('selfTrainVGG_withDA.pth')
122 | new_dict = OrderedDict()
123 | for k, v in ckpt.items():
124 | new_k = k.replace('module.', '')
125 | new_dict[new_k] = v
126 |
127 | torch.save(new_dict, 'selfTrainVGG_withDA.pth')
--------------------------------------------------------------------------------
/dataloader/ETH3D_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | from dataloader import readpfm as rp
  4 |
5 | import torch.utils.data as data
6 | import torchvision.transforms as transforms
7 | import numpy as np
8 | import random
9 |
10 | IMG_EXTENSIONS= [
11 | '.jpg', '.JPG', '.jpeg', '.JPEG',
12 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP'
13 | ]
14 |
15 |
16 | def is_image_file(filename):
17 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
18 |
19 |
20 | # filepath = '/media/data/dataset/ETH3D/'
21 | def et_loader(filepath):
22 |
23 | left_img = []
24 | right_img = []
25 | disp_gt = []
26 | occ_mask = []
27 |
28 | img_path = os.path.join(filepath, 'two_view_training')
29 | gt_path = os.path.join(filepath, 'two_view_training_gt')
30 |
31 | for c in os.listdir(img_path):
32 | img_cpath = os.path.join(img_path, c)
33 | gt_cpath = os.path.join(gt_path, c)
34 |
35 | left_img.append(os.path.join(img_cpath, 'im0.png'))
36 | right_img.append(os.path.join(img_cpath, 'im1.png'))
37 | disp_gt.append(os.path.join(gt_cpath, 'disp0GT.pfm'))
38 | occ_mask.append(os.path.join(gt_cpath, 'mask0nocc.png'))
39 |
40 | return left_img, right_img, disp_gt, occ_mask,
41 |
42 |
43 | def img_loader(path):
44 | return Image.open(path).convert('RGB')
45 |
46 |
47 | def disparity_loader(path):
48 | return rp.readPFM(path)
49 |
50 |
51 | class myDataset(data.Dataset):
52 |
53 | def __init__(self, left, right, disp_gt, occ_mask, training, imgloader=img_loader, dploader = disparity_loader):
54 | self.left = left
55 | self.right = right
56 | self.disp_gt = disp_gt
57 | self.occ_mask = occ_mask
58 | self.imgloader = imgloader
59 | self.dploader = dploader
60 | self.training = training
61 | self.img_transorm = transforms.Compose([
62 | transforms.ToTensor(),
63 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
64 |
65 | def __getitem__(self, index):
66 | left = self.left[index]
67 | right = self.right[index]
 68 |         disp_path = self.disp_gt[index]
 69 |         occ_path = self.occ_mask[index]
 70 | 
 71 |         left_img = self.imgloader(left)
 72 |         right_img = self.imgloader(right)
 73 |         dataL, _ = self.dploader(disp_path)
 74 |         dataL = np.ascontiguousarray(dataL, dtype=np.float32)
 75 |         # dataR holds the non-occluded-pixel mask (a PNG), read with PIL rather than the PFM loader
 76 |         dataR = np.ascontiguousarray(Image.open(occ_path), dtype=np.float32)
77 |
78 | if self.training:
79 | w, h = left_img.size
80 | tw, th = 512, 256
81 | x1 = random.randint(0, w - tw)
82 | y1 = random.randint(0, h - th)
83 |
84 | left_img = left_img.crop((x1, y1, x1+tw, y1+th))
85 | right_img = right_img.crop((x1, y1, x1+tw, y1+th))
86 | dataL = dataL[y1:y1+th, x1:x1+tw]
87 | dataR = dataR[y1:y1+th, x1:x1+tw]
88 |
89 | left_img = self.img_transorm(left_img)
90 | right_img = self.img_transorm(right_img)
91 |
92 | return left_img, right_img, dataL, dataR
93 |
94 | else:
95 | w, h = left_img.size
96 | left_img = left_img.crop((w-960, h-544, w, h))
97 | right_img = right_img.crop((w-960, h-544, w, h))
98 |
99 | left_img = self.img_transorm(left_img)
100 | right_img = self.img_transorm(right_img)
101 |
102 | dataL = Image.fromarray(dataL).crop((w-960, h-544, w, h))
103 | dataL = np.ascontiguousarray(dataL)
104 | dataR = Image.fromarray(dataR).crop((w-960, h-544, w, h))
105 | dataR = np.ascontiguousarray(dataR)
106 |
107 | return left_img, right_img, dataL, dataR
108 |
109 | def __len__(self):
110 | return len(self.left)
111 |
112 |
113 |
114 |
115 |
116 |
--------------------------------------------------------------------------------
/test_eth3d.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 | from torch.autograd import Variable
5 | from torch.autograd import grad as Grad
6 | from torchvision import transforms
7 | import os
8 | import copy
9 | import skimage.io
10 | from collections import OrderedDict
11 | from tqdm import tqdm, trange
12 | from PIL import Image
13 | import numpy as np
14 | import matplotlib.pyplot as plt
15 | import argparse
16 |
17 | from dataloader import ETH3D_loader as et
18 | from dataloader.readpfm import readPFM
19 | import networks.Aggregator as Agg
20 | import networks.feature_extraction as FE
21 | import networks.U_net as un
22 |
23 |
24 | parser = argparse.ArgumentParser(description='GraftNet')
25 | parser.add_argument('--no_cuda', action='store_true', default=False)
26 | parser.add_argument('--gpu_id', type=str, default='2')
27 | parser.add_argument('--seed', type=str, default=0)
28 | parser.add_argument('--data_path', type=str, default='/media/data/dataset/ETH3D/')
29 | parser.add_argument('--load_path', type=str, default='trained_models/checkpoint_final_10epoch.tar')
30 | parser.add_argument('--max_disp', type=int, default=192)
31 | args = parser.parse_args()
32 |
33 |
34 | if not args.no_cuda:
35 | os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
36 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
37 | cuda = torch.cuda.is_available()
38 |
39 |
40 | all_limg, all_rimg, all_disp, all_mask = et.et_loader(args.data_path)
41 |
42 |
43 | fe_model = FE.VGG_Feature(fixed_param=True).eval()
44 | adaptor = un.U_Net_v4(img_ch=256, output_ch=64).eval()
45 | agg_model = Agg.PSMAggregator(args.max_disp, udc=True).eval()
46 |
47 | if cuda:
48 | fe_model = nn.DataParallel(fe_model.cuda())
49 | adaptor = nn.DataParallel(adaptor.cuda())
50 | agg_model = nn.DataParallel(agg_model.cuda())
51 |
52 | adaptor.load_state_dict(torch.load(args.load_path)['fa_net'])
53 | agg_model.load_state_dict(torch.load(args.load_path)['net'])
54 |
55 |
56 | pred_mae = 0
57 | pred_op = 0
58 | for i in trange(len(all_limg)):
59 | limg = Image.open(all_limg[i]).convert('RGB')
60 | rimg = Image.open(all_rimg[i]).convert('RGB')
61 |
62 | w, h = limg.size
63 | wi, hi = (w // 16 + 1) * 16, (h // 16 + 1) * 16
64 | limg = limg.crop((w - wi, h - hi, w, h))
65 | rimg = rimg.crop((w - wi, h - hi, w, h))
66 |
67 | limg_tensor = transforms.Compose([
68 | transforms.ToTensor(),
69 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])(limg)
70 | rimg_tensor = transforms.Compose([
71 | transforms.ToTensor(),
72 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])(rimg)
73 | limg_tensor = limg_tensor.unsqueeze(0).cuda()
74 | rimg_tensor = rimg_tensor.unsqueeze(0).cuda()
75 |
76 | disp_gt, _ = readPFM(all_disp[i])
77 | disp_gt = np.ascontiguousarray(disp_gt, dtype=np.float32)
78 | disp_gt[disp_gt == np.inf] = 0
79 | gt_tensor = torch.FloatTensor(disp_gt).unsqueeze(0).unsqueeze(0).cuda()
80 |
81 | occ_mask = np.ascontiguousarray(Image.open(all_mask[i]))
82 |
83 | with torch.no_grad():
84 | left_fea = fe_model(limg_tensor)
85 | right_fea = fe_model(rimg_tensor)
86 |
87 | left_fea = adaptor(left_fea)
88 | right_fea = adaptor(right_fea)
89 |
90 | pred_disp = agg_model(left_fea, right_fea, gt_tensor, training=False)
91 | pred_disp = pred_disp[:, hi - h:, wi - w:]
92 |
93 | predict_np = pred_disp.squeeze().cpu().numpy()
94 |
95 | op_thresh = 1
96 | mask = (disp_gt > 0) & (occ_mask == 255)
97 | # mask = disp_gt > 0
 98 |         # absolute disparity error on valid, non-occluded pixels
 99 |         pred_error = np.abs(predict_np * mask.astype(np.float32) - disp_gt * mask.astype(np.float32))
100 | 
101 | pred_op += np.sum(pred_error > op_thresh) / np.sum(mask)
102 | pred_mae += np.mean(pred_error[mask])
103 |
104 | print(pred_mae / len(all_limg))
105 | print(pred_op / len(all_limg))
--------------------------------------------------------------------------------
/networks/stackhourglass.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import torch
3 | import torch.nn as nn
4 | import torch.utils.data
5 | from torch.autograd import Variable
6 | import torch.nn.functional as F
7 | import math
8 | from networks.submodule import convbn, convbn_3d, DisparityRegression
9 |
10 |
11 | class hourglass(nn.Module):
12 | def __init__(self, inplanes):
13 | super(hourglass, self).__init__()
14 |
15 | self.conv1 = nn.Sequential(convbn_3d(inplanes, inplanes*2, kernel_size=3, stride=2, pad=1),
16 | nn.ReLU(inplace=True))
17 |
18 | self.conv2 = convbn_3d(inplanes*2, inplanes*2, kernel_size=3, stride=1, pad=1)
19 |
20 | self.conv3 = nn.Sequential(convbn_3d(inplanes*2, inplanes*2, kernel_size=3, stride=2, pad=1),
21 | nn.ReLU(inplace=True))
22 |
23 | self.conv4 = nn.Sequential(convbn_3d(inplanes*2, inplanes*2, kernel_size=3, stride=1, pad=1),
24 | nn.ReLU(inplace=True))
25 |
26 | self.conv5 = nn.Sequential(nn.ConvTranspose3d(inplanes*2, inplanes*2, kernel_size=3, padding=1, output_padding=1, stride=2,bias=False),
27 | nn.BatchNorm3d(inplanes*2)) #+conv2
28 |
29 | self.conv6 = nn.Sequential(nn.ConvTranspose3d(inplanes*2, inplanes, kernel_size=3, padding=1, output_padding=1, stride=2,bias=False),
30 | nn.BatchNorm3d(inplanes)) #+x
31 |
32 | def forward(self, x ,presqu, postsqu):
33 |
34 | out = self.conv1(x) #in:1/4 out:1/8
35 | pre = self.conv2(out) #in:1/8 out:1/8
36 |
37 | if postsqu is not None:
38 | pre = F.relu(pre + postsqu, inplace=True)
39 | else:
40 | pre = F.relu(pre, inplace=True)
41 |
42 | # print('pre2', pre.size())
43 |
44 | out = self.conv3(pre) #in:1/8 out:1/16
45 | out = self.conv4(out) #in:1/16 out:1/16
46 |
47 | # print('out', out.size())
48 |
49 | if presqu is not None:
50 | post = F.relu(self.conv5(out)+presqu, inplace=True) #in:1/16 out:1/8
51 | else:
52 | post = F.relu(self.conv5(out)+pre, inplace=True)
53 |
54 | out = self.conv6(post) #in:1/8 out:1/4
55 |
56 | return out, pre, post
57 |
58 |
59 | class hourglass_gwcnet(nn.Module):
60 | def __init__(self, inplanes):
61 | super(hourglass_gwcnet, self).__init__()
62 |
63 | self.conv1 = nn.Sequential(convbn_3d(inplanes, inplanes * 2, kernel_size=3, stride=2, pad=1),
64 | nn.ReLU(inplace=True))
65 | self.conv2 = nn.Sequential(convbn_3d(inplanes * 2, inplanes * 2, kernel_size=3, stride=1, pad=1),
66 | nn.ReLU(inplace=True))
67 | self.conv3 = nn.Sequential(convbn_3d(inplanes * 2, inplanes * 4, kernel_size=3, stride=2, pad=1),
68 | nn.ReLU(inplace=True))
69 | self.conv4 = nn.Sequential(convbn_3d(inplanes * 4, inplanes * 4, 3, 1, 1),
70 | nn.ReLU(inplace=True))
71 | self.conv5 = nn.Sequential(nn.ConvTranspose3d(inplanes * 4, inplanes * 2, kernel_size=3, padding=1,
72 | output_padding=1, stride=2, bias=False),
73 | nn.BatchNorm3d(inplanes * 2))
74 | self.conv6 = nn.Sequential(nn.ConvTranspose3d(inplanes * 2, inplanes, kernel_size=3, padding=1,
75 | output_padding=1, stride=2, bias=False),
76 | nn.BatchNorm3d(inplanes))
77 |
78 | self.redir1 = convbn_3d(inplanes, inplanes, kernel_size=1, stride=1, pad=0)
79 | self.redir2 = convbn_3d(inplanes * 2, inplanes * 2, kernel_size=1, stride=1, pad=0)
80 |
81 | def forward(self, x):
82 |
83 | conv1 = self.conv1(x)
84 | conv2 = self.conv2(conv1)
85 |
86 | conv3 = self.conv3(conv2)
87 | conv4 = self.conv4(conv3)
88 |
89 | conv5 = F.relu(self.conv5(conv4) + self.redir2(conv2), inplace=True)
90 | conv6 = F.relu(self.conv6(conv5) + self.redir1(x), inplace=True)
91 |
92 | return conv6
93 |
94 |
--------------------------------------------------------------------------------
/test_kitti.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 | from torch.autograd import Variable
5 | from torch.autograd import grad as Grad
6 | from torchvision import transforms
7 | import skimage.io
8 | import os
9 | import copy
10 | from collections import OrderedDict
11 | from tqdm import tqdm, trange
12 | from PIL import Image
13 | import numpy as np
14 | import matplotlib.pyplot as plt
15 | import argparse
16 |
17 | from dataloader import KITTIloader as kt
18 | from dataloader import KITTI2012loader as kt2012
19 | import networks.Aggregator as Agg
20 | import networks.feature_extraction as FE
21 | import networks.U_net as un
22 |
23 |
24 | parser = argparse.ArgumentParser(description='GraftNet')
25 | parser.add_argument('--no_cuda', action='store_true', default=False)
26 | parser.add_argument('--gpu_id', type=str, default='2')
27 | parser.add_argument('--seed', type=str, default=0)
28 | parser.add_argument('--kitti', type=str, default='2015')
29 | parser.add_argument('--data_path', type=str, default='/media/data/dataset/KITTI/data_scene_flow/training/')
30 | parser.add_argument('--load_path', type=str, default='trained_models/checkpoint_final_10epoch.tar')
31 | parser.add_argument('--max_disp', type=int, default=192)
32 | args = parser.parse_args()
33 |
34 | if not args.no_cuda:
35 | os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
36 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
37 | cuda = torch.cuda.is_available()
38 |
39 |
40 | if args.kitti == '2015':
41 | all_limg, all_rimg, all_ldisp, test_limg, test_rimg, test_ldisp = kt.kt_loader(args.data_path)
42 | else:
43 | all_limg, all_rimg, all_ldisp, test_limg, test_rimg, test_ldisp = kt2012.kt2012_loader(args.data_path)
44 |
45 | test_limg = all_limg + test_limg
46 | test_rimg = all_rimg + test_rimg
47 | test_ldisp = all_ldisp + test_ldisp
48 |
49 | fe_model = FE.VGG_Feature(fixed_param=True).eval()
50 | adaptor = un.U_Net_v4(img_ch=256, output_ch=64).eval()
51 | agg_model = Agg.PSMAggregator(args.max_disp, udc=True).eval()
52 |
53 | if cuda:
54 | fe_model = nn.DataParallel(fe_model.cuda())
55 | adaptor = nn.DataParallel(adaptor.cuda())
56 | agg_model = nn.DataParallel(agg_model.cuda())
57 |
58 | adaptor.load_state_dict(torch.load(args.load_path)['fa_net'])
59 | agg_model.load_state_dict(torch.load(args.load_path)['net'])
60 |
61 | pred_mae = 0
62 | pred_op = 0
63 | for i in trange(len(test_limg)):
64 | limg = Image.open(test_limg[i]).convert('RGB')
65 | rimg = Image.open(test_rimg[i]).convert('RGB')
66 |
67 | w, h = limg.size
68 | m = 16
 69 |     wi, hi = (w // m + 1) * m, (h // m + 1) * m  # pad up to the next multiple of 16; the crop box below extends past the top-left corner, so PIL zero-fills the border
70 | limg = limg.crop((w - wi, h - hi, w, h))
71 | rimg = rimg.crop((w - wi, h - hi, w, h))
72 |
73 | transform = transforms.Compose([
74 | transforms.ToTensor(),
75 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
76 |
77 | limg_tensor = transform(limg)
78 | rimg_tensor = transform(rimg)
79 | limg_tensor = limg_tensor.unsqueeze(0).cuda()
80 | rimg_tensor = rimg_tensor.unsqueeze(0).cuda()
81 |
82 | disp_gt = Image.open(test_ldisp[i])
83 | disp_gt = np.ascontiguousarray(disp_gt, dtype=np.float32) / 256
84 | gt_tensor = torch.FloatTensor(disp_gt).unsqueeze(0).unsqueeze(0).cuda()
85 |
86 | with torch.no_grad():
87 | left_fea = fe_model(limg_tensor)
88 | right_fea = fe_model(rimg_tensor)
89 |
90 | left_fea = adaptor(left_fea)
91 | right_fea = adaptor(right_fea)
92 |
93 | pred_disp = agg_model(left_fea, right_fea, gt_tensor, training=False)
94 | pred_disp = pred_disp[:, hi - h:, wi - w:]
95 |
96 | predict_np = pred_disp.squeeze().cpu().numpy()
97 |
98 | op_thresh = 3
99 | mask = (disp_gt > 0) & (disp_gt < args.max_disp)
100 |     # absolute disparity error on valid pixels (0 < d_gt < max_disp)
101 |     pred_error = np.abs(predict_np * mask.astype(np.float32) - disp_gt * mask.astype(np.float32))
102 | 
103 | pred_op += np.sum((pred_error > op_thresh)) / np.sum(mask)
104 | pred_mae += np.mean(pred_error[mask])
105 |
106 | print(pred_mae / len(test_limg))
107 | print(pred_op / len(test_limg))
--------------------------------------------------------------------------------
/dataloader/vKITTI_loader.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | import torchvision.transforms as transforms
3 | import os
4 | from PIL import Image
5 | import random
6 | import numpy as np
7 |
8 |
9 | def vkt_loader(filepath):
10 | all_limg = []
11 | all_rimg = []
12 | all_disp = []
13 |
14 | img_path = os.path.join(filepath, 'vkitti_2.0.3_rgb')
15 | depth_path = os.path.join(filepath, 'vkitti_2.0.3_depth')
16 |
17 | for scene in os.listdir(img_path):
18 | img_scenes_path = os.path.join(img_path, scene, 'clone/frames/rgb')
19 | depth_scenes_path = os.path.join(depth_path, scene, 'clone/frames/depth')
20 |
21 | for name in os.listdir(os.path.join(img_scenes_path, 'Camera_0')):
22 | all_limg.append(os.path.join(img_scenes_path, 'Camera_0', name))
23 | all_rimg.append(os.path.join(img_scenes_path, 'Camera_1', name))
24 | all_disp.append(os.path.join(depth_scenes_path, 'Camera_0',
25 | name.replace('jpg', 'png').replace('rgb', 'depth')))
26 |
27 | total_num = len(all_limg)
28 | train_length = int(total_num * 0.75)
29 |
30 | train_limg = all_limg[:train_length]
31 | train_rimg = all_rimg[:train_length]
32 | train_disp = all_disp[:train_length]
33 |
34 | val_limg = all_limg[train_length:]
35 | val_rimg = all_rimg[train_length:]
36 | val_disp = all_disp[train_length:]
37 |
38 | return train_limg, train_rimg, train_disp, val_limg, val_rimg, val_disp
39 |
40 |
41 | def img_loader(path):
42 | return Image.open(path).convert('RGB')
43 |
44 |
45 | def disparity_loader(path):
46 | return Image.open(path)
47 |
48 |
49 | class vkDataset(data.Dataset):
50 |
51 | def __init__(self, left, right, left_disp, training, imgloader=img_loader, disploader=disparity_loader):
52 | self.left = left
53 | self.right = right
54 | self.left_disp = left_disp
55 | self.imgloader = imgloader
56 | self.disploader = disploader
57 | self.training = training
58 | self.transform = transforms.Compose([
59 | transforms.ToTensor(),
60 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
61 | ])
62 |
63 | def __getitem__(self, index):
64 | left = self.left[index]
65 | right = self.right[index]
66 | left_disp = self.left_disp[index]
67 |
68 | limg = self.imgloader(left)
69 | rimg = self.imgloader(right)
70 | ldisp = self.disploader(left_disp)
71 |
72 | if self.training:
73 | w, h = limg.size
74 | tw, th = 512, 256
75 |
76 | x1 = random.randint(0, w - tw)
77 | y1 = random.randint(0, h - th)
78 |
79 | limg = limg.crop((x1, y1, x1 + tw, y1 + th))
80 | rimg = rimg.crop((x1, y1, x1 + tw, y1 + th))
81 |
82 | limg = self.transform(limg)
83 | rimg = self.transform(rimg)
84 |
85 | baseline, fx, fy = 0.532725, 725.0087, 725.0087
86 | camera_params = {'baseline': baseline,
87 | 'fx': fx,
88 | 'fy': fy}
89 |
 90 |             ldepth = np.ascontiguousarray(ldisp, dtype=np.float32) / 100.  # depth PNG stores centimeters; convert to meters
 91 |             ldisp = baseline * fy / ldepth  # disparity = baseline * focal length / depth
92 | ldisp = ldisp[y1:y1 + th, x1:x1 + tw]
93 |
94 | return limg, rimg, ldisp, ldisp
95 |
96 | else:
97 | w, h = limg.size
98 |
99 | limg = limg.crop((w-1232, h-368, w, h))
100 | rimg = rimg.crop((w-1232, h-368, w, h))
101 | ldisp = ldisp.crop((w-1232, h-368, w, h))
102 |
103 | limg = self.transform(limg)
104 | rimg = self.transform(rimg)
105 |
106 | baseline, fx, fy = 0.532725, 725.0087, 725.0087
107 | ldepth = np.ascontiguousarray(ldisp, dtype=np.float32) / 100.
108 | ldisp = baseline * fy / ldepth
109 |
110 | return limg, rimg, ldisp, ldisp
111 |
112 | def __len__(self):
113 | return len(self.left)
114 |
115 |
116 | if __name__ == '__main__':
117 |
118 | path = '/media/data2/Dataset/vKITTI2/'
119 | a, b, c, d, e, f = vkt_loader(path)
120 | print(len(a))
--------------------------------------------------------------------------------
/test_middlebury.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torchvision import transforms
5 | from torch.autograd import Variable
6 | from torch.autograd import grad as Grad
7 | import skimage.io
8 | import os
9 | import copy
10 | from collections import OrderedDict
11 | from tqdm import tqdm, trange
12 | from PIL import Image
13 | import numpy as np
14 | import matplotlib.pyplot as plt
15 | import cv2
16 | import argparse
17 |
18 | from dataloader import middlebury_loader as mb
19 | from dataloader import readpfm as rp
20 | import networks.Aggregator as Agg
21 | import networks.U_net as un
22 | import networks.feature_extraction as FE
23 |
24 |
25 | parser = argparse.ArgumentParser(description='GraftNet')
26 | parser.add_argument('--no_cuda', action='store_true', default=False)
27 | parser.add_argument('--gpu_id', type=str, default='2')
28 | parser.add_argument('--seed', type=str, default=0)
29 | parser.add_argument('--resolution', type=str, default='H')
30 | parser.add_argument('--data_path', type=str, default='/media/data/dataset/MiddEval3-data-H/')
31 | parser.add_argument('--load_path', type=str, default='trained_models/checkpoint_final_10epoch.tar')
32 | parser.add_argument('--max_disp', type=int, default=192)
33 | args = parser.parse_args()
34 |
35 | if not args.no_cuda:
36 | os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
37 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
38 | cuda = torch.cuda.is_available()
39 |
40 | train_limg, train_rimg, train_gt, test_limg, test_rimg = mb.mb_loader(args.data_path, res=args.resolution)
41 |
42 | fe_model = FE.VGG_Feature(fixed_param=True).eval()
43 | adaptor = un.U_Net_v4(img_ch=256, output_ch=64).eval()
44 | agg_model = Agg.PSMAggregator(args.max_disp, udc=True).eval()
45 |
46 | if cuda:
47 | fe_model = nn.DataParallel(fe_model.cuda())
48 | adaptor = nn.DataParallel(adaptor.cuda())
49 | agg_model = nn.DataParallel(agg_model.cuda())
50 |
51 | adaptor.load_state_dict(torch.load(args.load_path)['fa_net'])
52 | agg_model.load_state_dict(torch.load(args.load_path)['net'])
53 |
54 |
55 | def test_trainset():
56 | op = 0
57 | mae = 0
58 |
59 | for i in trange(len(train_limg)):
60 |
61 | limg_path = train_limg[i]
62 | rimg_path = train_rimg[i]
63 |
64 | limg = Image.open(limg_path).convert('RGB')
65 | rimg = Image.open(rimg_path).convert('RGB')
66 |
67 | w, h = limg.size
68 | wi, hi = (w // 16 + 1) * 16, (h // 16 + 1) * 16
69 |
70 | limg = limg.crop((w - wi, h - hi, w, h))
71 | rimg = rimg.crop((w - wi, h - hi, w, h))
72 |
73 | limg_tensor = transforms.Compose([
74 | transforms.ToTensor(),
75 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])(limg)
76 | rimg_tensor = transforms.Compose([
77 | transforms.ToTensor(),
78 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])(rimg)
79 | limg_tensor = limg_tensor.unsqueeze(0).cuda()
80 | rimg_tensor = rimg_tensor.unsqueeze(0).cuda()
81 |
82 | with torch.no_grad():
83 | left_fea = fe_model(limg_tensor)
84 | right_fea = fe_model(rimg_tensor)
85 |
86 | left_fea = adaptor(left_fea)
87 | right_fea = adaptor(right_fea)
88 |
89 | pred_disp = agg_model(left_fea, right_fea, limg_tensor, training=False)
90 | pred_disp = pred_disp[:, hi - h:, wi - w:]
91 |
92 | pred_np = pred_disp.squeeze().cpu().numpy()
93 |
94 | torch.cuda.empty_cache()
95 |
96 | disp_gt, _ = rp.readPFM(train_gt[i])
97 | disp_gt = np.ascontiguousarray(disp_gt, dtype=np.float32)
98 | disp_gt[disp_gt == np.inf] = 0
99 |
100 | occ_mask = Image.open(train_gt[i].replace('disp0GT.pfm', 'mask0nocc.png')).convert('L')
101 | occ_mask = np.ascontiguousarray(occ_mask, dtype=np.float32)
102 |
103 | mask = (disp_gt <= 0) | (occ_mask != 255) | (disp_gt >= args.max_disp)
104 | # mask = (disp_gt <= 0) | (disp_gt >= maxdisp)
105 |
106 | error = np.abs(pred_np - disp_gt)
107 | error[mask] = 0
108 |
109 |         # all scenes are weighted equally in the averages below
110 |         k = 1
111 | 
112 | 
113 |
114 | op += np.sum(error > 2.0) / (w * h - np.sum(mask)) * k
115 | mae += np.sum(error) / (w * h - np.sum(mask)) * k
116 |
117 | print(op / 15 * 100)
118 | print(mae / 15)
119 |
120 |
121 | if __name__ == '__main__':
122 | test_trainset()
123 | # test_testset()
--------------------------------------------------------------------------------
/dataloader/middlebury_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | from dataloader import readpfm as rp
4 | import torch.utils.data as data
5 | import torchvision.transforms as transforms
6 | import numpy as np
7 | import random
8 |
9 |
10 | def mb_loader(filepath, res):
11 |
12 | train_path = os.path.join(filepath, 'training' + res)
13 | test_path = os.path.join(filepath, 'test' + res)
14 | gt_path = train_path.replace('training' + res, 'Eval3_GT/training' + res)
15 |
16 | train_left = []
17 | train_right = []
18 | train_gt = []
19 |
20 | for c in os.listdir(train_path):
21 | train_left.append(os.path.join(train_path, c, 'im0.png'))
22 | train_right.append(os.path.join(train_path, c, 'im1.png'))
23 | train_gt.append(os.path.join(gt_path, c, 'disp0GT.pfm'))
24 |
25 | test_left = []
26 | test_right = []
27 | for c in os.listdir(test_path):
28 | test_left.append(os.path.join(test_path, c, 'im0.png'))
29 | test_right.append(os.path.join(test_path, c, 'im1.png'))
30 |
31 | train_left = sorted(train_left)
32 | train_right = sorted(train_right)
33 | train_gt = sorted(train_gt)
34 | test_left = sorted(test_left)
35 | test_right = sorted(test_right)
36 |
37 | return train_left, train_right, train_gt, test_left, test_right
38 |
39 |
40 | def img_loader(path):
41 | return Image.open(path).convert('RGB')
42 |
43 |
44 | def disparity_loader(path):
45 | return rp.readPFM(path)
46 |
47 |
48 | class myDataset(data.Dataset):
49 |
50 | def __init__(self, left, right, left_disp, training, imgloader=img_loader, dploader = disparity_loader):
51 | self.left = left
52 | self.right = right
53 | self.disp_L = left_disp
54 | self.imgloader = imgloader
55 | self.dploader = dploader
56 | self.training = training
57 | self.img_transorm = transforms.Compose([
58 | transforms.ToTensor(),
59 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
60 |
61 | def __getitem__(self, index):
62 | left = self.left[index]
63 | right = self.right[index]
64 | disp_L = self.disp_L[index]
65 |
66 | left_img = self.imgloader(left)
67 | right_img = self.imgloader(right)
68 | dataL, scaleL = self.dploader(disp_L)
69 | dataL = Image.fromarray(np.ascontiguousarray(dataL, dtype=np.float32))
70 |
71 | if self.training:
72 | w, h = left_img.size
73 |
74 | # random resize
 75 |             s = float(np.random.uniform(0.95, 1.05))
 76 |             rw, rh = int(round(w * s)), int(round(h * s))
77 | left_img = left_img.resize((rw, rh), Image.NEAREST)
78 | right_img = right_img.resize((rw, rh), Image.NEAREST)
79 | dataL = dataL.resize((rw, rh), Image.NEAREST)
80 | dataL = Image.fromarray(np.array(dataL) * s)
81 |
82 | # random horizontal flip
83 | p = np.random.rand(1)
84 | if p >= 0.5:
85 | left_img = horizontal_flip(left_img)
86 | right_img = horizontal_flip(right_img)
87 | dataL = horizontal_flip(dataL)
88 |
89 | w, h = left_img.size
90 | tw, th = 320, 240
91 | x1 = random.randint(0, w - tw)
92 | y1 = random.randint(0, h - th)
93 |
94 | left_img = left_img.crop((x1, y1, x1+tw, y1+th))
95 | right_img = right_img.crop((x1, y1, x1+tw, y1+th))
96 | dataL = dataL.crop((x1, y1, x1+tw, y1+th))
97 |
98 | left_img = self.img_transorm(left_img)
99 | right_img = self.img_transorm(right_img)
100 |
101 | dataL = np.array(dataL)
102 | return left_img, right_img, dataL
103 |
104 | else:
105 | w, h = left_img.size
106 | left_img = left_img.resize((w // 32 * 32, h // 32 * 32))
107 | right_img = right_img.resize((w // 32 * 32, h // 32 * 32))
108 |
109 | left_img = self.img_transorm(left_img)
110 | right_img = self.img_transorm(right_img)
111 |
112 | dataL = np.array(dataL)
113 | return left_img, right_img, dataL
114 |
115 | def __len__(self):
116 | return len(self.left)
117 |
118 |
119 | def horizontal_flip(img):
120 | img_np = np.array(img)
121 | img_np = np.flip(img_np, axis=1)
122 | img = Image.fromarray(img_np)
123 | return img
124 |
125 |
126 | if __name__ == '__main__':
127 | train_left, train_right, train_gt, _, _ = mb_loader('/media/data/dataset/MiddEval3-data-Q/', res='Q')
128 | H, W = 0, 0
129 | for l in train_right:
130 | left_img = Image.open(l).convert('RGB')
131 |         w, h = left_img.size  # PIL size is (width, height)
132 | H += h
133 | W += w
134 | print(H / 15, W / 15)
--------------------------------------------------------------------------------
/dataloader/KITTIloader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data as data
3 | import torchvision.transforms as transforms
4 | import os
5 | from PIL import Image
6 | import random
7 | import numpy as np
8 |
9 |
10 | IMG_EXTENSIONS= [
11 | '.jpg', '.JPG', '.jpeg', '.JPEG',
12 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP'
13 | ]
14 |
15 |
16 | def is_image_file(filename):
17 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
18 |
19 |
20 | def kt_loader(filepath):
21 |
22 | left_path = os.path.join(filepath, 'image_2')
23 | right_path = os.path.join(filepath, 'image_3')
24 | displ_path = os.path.join(filepath, 'disp_occ_0')
25 |
26 | # total_name = sorted([name for name in os.listdir(left_path) if name.find('_10') > -1])
27 | total_name = [name for name in os.listdir(left_path) if name.find('_10') > -1]
28 | train_name = total_name[:160]
29 | val_name = total_name[160:]
30 |
31 | train_left = []
32 | train_right = []
33 | train_displ = []
34 | for name in train_name:
35 | train_left.append(os.path.join(left_path, name))
36 | train_right.append(os.path.join(right_path, name))
37 | train_displ.append(os.path.join(displ_path, name))
38 |
39 | val_left = []
40 | val_right = []
41 | val_displ = []
42 | for name in val_name:
43 | val_left.append(os.path.join(left_path, name))
44 | val_right.append(os.path.join(right_path, name))
45 | val_displ.append(os.path.join(displ_path, name))
46 |
47 | return train_left, train_right, train_displ, val_left, val_right, val_displ
48 |
49 |
50 | def kt2012_loader(filepath):
51 |
52 | left_path = os.path.join(filepath, 'colored_0')
53 | right_path = os.path.join(filepath, 'colored_1')
54 | displ_path = os.path.join(filepath, 'disp_occ')
55 |
56 | total_name = sorted([name for name in os.listdir(left_path) if name.find('_10') > -1])
57 | train_name = total_name[:160]
58 | val_name = total_name[160:]
59 |
60 | train_left = []
61 | train_right = []
62 | train_displ = []
63 | for name in train_name:
64 | train_left.append(os.path.join(left_path, name))
65 | train_right.append(os.path.join(right_path, name))
66 | train_displ.append(os.path.join(displ_path, name))
67 |
68 | val_left = []
69 | val_right = []
70 | val_displ = []
71 | for name in val_name:
72 | val_left.append(os.path.join(left_path, name))
73 | val_right.append(os.path.join(right_path, name))
74 | val_displ.append(os.path.join(displ_path, name))
75 |
76 | return train_left, train_right, train_displ, val_left, val_right, val_displ
77 |
78 |
79 | def img_loader(path):
80 | return Image.open(path).convert('RGB')
81 |
82 |
83 | def disparity_loader(path):
84 | return Image.open(path)
85 |
86 |
87 | class myDataset(data.Dataset):
88 |
89 | def __init__(self, left, right, left_disp, training, imgloader=img_loader, disploader=disparity_loader):
90 | self.left = left
91 | self.right = right
92 | self.left_disp = left_disp
93 |
94 | self.training = training
95 | self.imgloader = imgloader
96 | self.disploader = disploader
97 |
98 | self.transform = transforms.Compose([
99 | transforms.ToTensor(),
100 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
101 |
102 | def __getitem__(self, index):
103 | left = self.left[index]
104 | right = self.right[index]
105 | left_disp = self.left_disp[index]
106 |
107 | limg = self.imgloader(left)
108 | rimg = self.imgloader(right)
109 | ldisp = self.disploader(left_disp)
110 |
111 | # W, H = limg.size
112 | # limg = limg.resize((960, 288))
113 | # rimg = rimg.resize((960, 288))
114 | # ldisp = ldisp.resize((960, 288), Image.NEAREST)
115 |
116 | if self.training:
117 | w, h = limg.size
118 | tw, th = 512, 256
119 |
120 | x1 = random.randint(0, w - tw)
121 | y1 = random.randint(0, h - th)
122 |
123 | limg = limg.crop((x1, y1, x1 + tw, y1 + th))
124 | rimg = rimg.crop((x1, y1, x1 + tw, y1 + th))
125 | ldisp = np.ascontiguousarray(ldisp, dtype=np.float32) / 256
126 | ldisp = ldisp[y1:y1 + th, x1:x1 + tw]
127 |
128 | limg = self.transform(limg)
129 | rimg = self.transform(rimg)
130 |
131 | else:
132 | w, h = limg.size
133 |
134 | limg = limg.crop((w-1232, h-368, w, h))
135 | rimg = rimg.crop((w-1232, h-368, w, h))
136 | ldisp = ldisp.crop((w-1232, h-368, w, h))
137 | ldisp = np.ascontiguousarray(ldisp, dtype=np.float32)/256
138 |
139 | limg = self.transform(limg)
140 | rimg = self.transform(rimg)
141 |
142 | # ldisp = ldisp * (960/W)
143 | return limg, rimg, ldisp, ldisp
144 |
145 | def __len__(self):
146 | return len(self.left)
147 |
148 |
--------------------------------------------------------------------------------
/train_baseline.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import torch.utils.data
4 | import torch.optim as optim
5 | import torch.nn.functional as F
6 | import torch.nn as nn
7 | import os
8 | import copy
9 | from tqdm import tqdm
10 |
11 | from dataloader import sceneflow_loader as sf
12 | import networks.submodule as sm
13 | import networks.U_net as un
14 | import networks.Aggregator as Agg
15 | import networks.feature_extraction as FE
16 | import loss_functions as lf
17 |
18 |
19 | parser = argparse.ArgumentParser(description='GraftNet')
20 | parser.add_argument('--no_cuda', action='store_true', default=False)
21 | parser.add_argument('--gpu_id', type=str, default='0, 1')
22 | parser.add_argument('--seed', type=int, default=0)
23 | parser.add_argument('--batch_size', type=int, default=6)
24 | parser.add_argument('--epoch', type=int, default=8)
25 | parser.add_argument('--data_path', type=str, default='/media/data/dataset/SceneFlow/')
26 | parser.add_argument('--save_path', type=str, default='trained_models/')
27 | parser.add_argument('--max_disp', type=int, default=192)
28 | parser.add_argument('--color_transform', action='store_true', default=False)
29 | args = parser.parse_args()
30 |
31 | if not args.no_cuda:
32 | os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
33 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
34 | cuda = torch.cuda.is_available()
35 |
36 | torch.manual_seed(args.seed)
37 | if cuda:
38 | torch.cuda.manual_seed(args.seed)
39 |
40 |
41 | all_limg, all_rimg, all_ldisp, all_rdisp, test_limg, test_rimg, test_ldisp, test_rdisp = sf.sf_loader(args.data_path)
42 |
43 | trainLoader = torch.utils.data.DataLoader(
44 | sf.myDataset(all_limg, all_rimg, all_ldisp, all_rdisp, training=True, color_transform=args.color_transform),
45 | batch_size=args.batch_size, shuffle=True, num_workers=4, drop_last=False)
46 |
47 |
48 | fe_model = sm.GwcFeature(out_c=64).train()
49 | model = Agg.PSMAggregator(args.max_disp, udc=True).train()
50 |
51 | if cuda:
52 | fe_model = nn.DataParallel(fe_model.cuda())
53 | model = nn.DataParallel(model.cuda())
54 |
55 | params = [
56 | {'params': fe_model.parameters(), 'lr': 1e-3},
57 | {'params': model.parameters(), 'lr': 1e-3},
58 | ]
59 | optimizer = optim.Adam(params, lr=1e-3, betas=(0.9, 0.999))
60 |
61 |
62 | def train(imgL, imgR, gt_left, gt_right):
63 | imgL = torch.FloatTensor(imgL)
64 | imgR = torch.FloatTensor(imgR)
65 | gt_left = torch.FloatTensor(gt_left)
66 | gt_right = torch.FloatTensor(gt_right)
67 |
68 | if cuda:
69 | imgL, imgR = imgL.cuda(), imgR.cuda()
70 | gt_left, gt_right = gt_left.cuda(), gt_right.cuda()
71 |
72 | optimizer.zero_grad()
73 |
74 | left_fea = fe_model(imgL)
75 | right_fea = fe_model(imgR)
76 |
77 | loss1, loss2 = model(left_fea, right_fea, gt_left, training=True)
78 |
79 | loss1 = torch.mean(loss1)
80 | loss2 = torch.mean(loss2)
81 |
82 | loss = 0.1 * loss1 + loss2
83 |
84 | loss.backward()
85 | optimizer.step()
86 |
87 | return loss1.item(), loss2.item()
88 |
89 |
90 | def adjust_learning_rate(optimizer, epoch):
91 | if epoch <= 10:
92 | lr = 0.001
93 | else:
94 | lr = 0.0001
95 | # print(lr)
96 | for param_group in optimizer.param_groups:
97 | param_group['lr'] = lr
98 |
99 |
100 | def main():
101 |
102 | # start_total_time = time.time()
103 | start_epoch = 1
104 |
105 | # checkpoint = torch.load('trained_gwcAgg/checkpoint_5_v1.tar')
106 | # model.load_state_dict(checkpoint['net'])
107 | # optimizer.load_state_dict(checkpoint['optimizer'])
108 | # start_epoch = checkpoint['epoch'] + 1
109 | # new_dict = {}
110 | # for k, v in checkpoint['fe_net'].items():
111 | # k = "module." + k
112 | # new_dict[k] = v
113 | # fe_model.load_state_dict(new_dict)
114 | # optimizer_fe.load_state_dict(checkpoint['fe_optimizer'])
115 |
116 | for epoch in range(start_epoch, args.epoch + start_epoch):
117 | print('This is %d-th epoch' % (epoch))
118 | total_train_loss1 = 0
119 | total_train_loss2 = 0
120 | adjust_learning_rate(optimizer, epoch)
121 |
122 | for batch_id, (imgL, imgR, disp_L, disp_R) in enumerate(tqdm(trainLoader)):
123 | train_loss1, train_loss2 = train(imgL, imgR, disp_L, disp_R)
124 | total_train_loss1 += train_loss1
125 | total_train_loss2 += train_loss2
126 | avg_train_loss1 = total_train_loss1 / len(trainLoader)
127 | avg_train_loss2 = total_train_loss2 / len(trainLoader)
128 | print('Epoch %d average training loss1 = %.3f, average training loss2 = %.3f' %
129 | (epoch, avg_train_loss1, avg_train_loss2))
130 |
131 | state = {'net': model.state_dict(),
132 | 'fe_net': fe_model.state_dict(),
133 | 'optimizer': optimizer.state_dict(),
134 | 'epoch': epoch}
135 | if not os.path.exists(args.save_path):
136 | os.mkdir(args.save_path)
137 | save_model_path = args.save_path + 'checkpoint_baseline_{}epoch.tar'.format(epoch)
138 | torch.save(state, save_model_path)
139 |
140 | torch.cuda.empty_cache()
141 |
142 |
143 | if __name__ == '__main__':
144 | main()
145 |
146 |
--------------------------------------------------------------------------------
/train_adaptor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data
3 | import torch.optim as optim
4 | import torch.nn.functional as F
5 | import torch.nn as nn
6 | import os
7 | import copy
8 | from tqdm import tqdm
9 | import matplotlib.pyplot as plt
10 | import argparse
11 |
12 | from dataloader import sceneflow_loader as sf
13 | import networks.Aggregator as Agg
14 | import networks.U_net as un
15 | import networks.feature_extraction as FE
16 | import loss_functions as lf
17 |
18 |
19 | parser = argparse.ArgumentParser(description='GraftNet')
20 | parser.add_argument('--no_cuda', action='store_true', default=False)
21 | parser.add_argument('--gpu_id', type=str, default='0, 1')
22 | parser.add_argument('--seed', type=int, default=0)
23 | parser.add_argument('--batch_size', type=int, default=8)
24 | parser.add_argument('--epoch', type=int, default=1)
25 | parser.add_argument('--data_path', type=str, default='/media/data/dataset/SceneFlow/')
26 | parser.add_argument('--save_path', type=str, default='trained_models/')
27 | parser.add_argument('--load_path', type=str, default='trained_models/checkpoint_baseline_8epoch.tar')
28 | parser.add_argument('--max_disp', type=int, default=192)
29 | parser.add_argument('--color_transform', action='store_true', default=False)
30 | args = parser.parse_args()
31 |
32 | if not args.no_cuda:
33 | os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
34 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
35 | cuda = torch.cuda.is_available()
36 |
37 | torch.manual_seed(args.seed)
38 | if cuda:
39 | torch.cuda.manual_seed(args.seed)
40 |
41 | all_limg, all_rimg, all_ldisp, all_rdisp, test_limg, test_rimg, test_ldisp, test_rdisp = sf.sf_loader(args.data_path)
42 |
43 | trainLoader = torch.utils.data.DataLoader(
44 | sf.myDataset(all_limg, all_rimg, all_ldisp, all_rdisp, training=True, color_transform=args.color_transform),
45 | batch_size=args.batch_size, shuffle=True, num_workers=4, drop_last=False)
46 |
47 |
48 | fe_model = FE.VGG_Feature(fixed_param=True).eval()
49 | model = un.U_Net_v4(img_ch=256, output_ch=64).train()
50 | print('Number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()])))
51 | agg_model = Agg.PSMAggregator(args.max_disp, udc=True).eval()
52 |
53 | if cuda:
54 | fe_model = nn.DataParallel(fe_model.cuda())
55 | model = nn.DataParallel(model.cuda())
56 | agg_model = nn.DataParallel(agg_model.cuda())
57 |
58 | agg_model.load_state_dict(torch.load(args.load_path)['net'])
59 | for p in agg_model.parameters():
60 | p.requires_grad = False
61 |
62 | optimizer = optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999))
63 |
64 |
65 | def train(imgL, imgR, gt_left, gt_right):
66 | imgL = torch.FloatTensor(imgL)
67 | imgR = torch.FloatTensor(imgR)
68 | gt_left = torch.FloatTensor(gt_left)
69 | gt_right = torch.FloatTensor(gt_right)
70 |
71 | if cuda:
72 | imgL, imgR, gt_left, gt_right = imgL.cuda(), imgR.cuda(), gt_left.cuda(), gt_right.cuda()
73 |
74 | optimizer.zero_grad()
75 |
76 | with torch.no_grad():
77 | left_fea = fe_model(imgL)
78 | right_fea = fe_model(imgR)
79 |
80 | agg_left_fea = model(left_fea)
81 | agg_right_fea = model(right_fea)
82 |
83 | loss1, loss2 = agg_model(agg_left_fea, agg_right_fea, gt_left, training=True)
84 | loss1 = torch.mean(loss1)
85 | loss2 = torch.mean(loss2)
86 | loss = 0.1 * loss1 + loss2
87 | # loss = loss1
88 |
89 | loss.backward()
90 | optimizer.step()
91 |
92 | return loss1.item(), loss2.item()
93 |
94 |
95 | def adjust_learning_rate(optimizer, epoch):
96 | if epoch <= 10:
97 | lr = 0.001
98 | else:
99 | lr = 0.0001
100 | # print(lr)
101 | for param_group in optimizer.param_groups:
102 | param_group['lr'] = lr
103 |
104 |
105 | def main():
106 |
107 | # start_total_time = time.time()
108 | start_epoch = 1
109 |
110 | # checkpoint = torch.load('trained_ft_CA_8.12/checkpoint_3_DA.tar')
111 | # agg_model.load_state_dict(checkpoint['net'])
112 | # optimizer.load_state_dict(checkpoint['optimizer'])
113 | # start_epoch = checkpoint['epoch'] + 1
114 |
115 | for epoch in range(start_epoch, args.epoch + start_epoch):
116 | print('This is %d-th epoch' % (epoch))
117 | total_train_loss1 = 0
118 | total_train_loss2 = 0
119 | adjust_learning_rate(optimizer, epoch)
120 |
121 | for batch_id, (imgL, imgR, disp_L, disp_R) in enumerate(tqdm(trainLoader)):
122 | train_loss1, train_loss2 = train(imgL, imgR, disp_L, disp_R)
123 | total_train_loss1 += train_loss1
124 | total_train_loss2 += train_loss2
125 | avg_train_loss1 = total_train_loss1 / len(trainLoader)
126 | avg_train_loss2 = total_train_loss2 / len(trainLoader)
127 | print('Epoch %d average training loss1 = %.3f, average training loss2 = %.3f' %
128 | (epoch, avg_train_loss1, avg_train_loss2))
129 |
130 | state = {'fa_net': model.state_dict(),
131 | 'net': agg_model.state_dict(),
132 | 'optimizer': optimizer.state_dict(),
133 | 'epoch': epoch}
134 | if not os.path.exists(args.save_path):
135 | os.mkdir(args.save_path)
136 | save_model_path = args.save_path + 'checkpoint_adaptor_{}epoch.tar'.format(epoch)
137 | torch.save(state, save_model_path)
138 |
139 | torch.cuda.empty_cache()
140 |
141 |
142 | if __name__ == '__main__':
143 | main()
144 |
145 |
--------------------------------------------------------------------------------
/retrain_CostAggregation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data
3 | import torch.optim as optim
4 | import torch.nn.functional as F
5 | import torch.nn as nn
6 | import os
7 | import copy
8 | from tqdm import tqdm, trange
9 | import matplotlib.pyplot as plt
10 | import argparse
11 |
12 | from dataloader import sceneflow_loader as sf
13 | import networks.Aggregator as Agg
14 | import networks.submodule as sm
15 | import networks.U_net as un
16 | import networks.feature_extraction as FE
17 | import loss_functions as lf
18 |
19 |
20 | parser = argparse.ArgumentParser(description='GraftNet')
21 | parser.add_argument('--no_cuda', action='store_true', default=False)
22 | parser.add_argument('--gpu_id', type=str, default='0, 1')
23 | parser.add_argument('--seed', type=int, default=0)
24 | parser.add_argument('--batch_size', type=int, default=6)
25 | parser.add_argument('--epoch', type=int, default=10)
26 | parser.add_argument('--data_path', type=str, default='/media/data/dataset/SceneFlow/')
27 | parser.add_argument('--save_path', type=str, default='trained_models/')
28 | parser.add_argument('--load_path', type=str, default='trained_models/checkpoint_adaptor_1epoch.tar')
29 | parser.add_argument('--max_disp', type=int, default=192)
30 | parser.add_argument('--color_transform', action='store_true', default=False)
31 | args = parser.parse_args()
32 |
33 | if not args.no_cuda:
34 | os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
35 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
36 | cuda = torch.cuda.is_available()
37 |
38 | torch.manual_seed(args.seed)
39 | if cuda:
40 | torch.cuda.manual_seed(args.seed)
41 |
42 | all_limg, all_rimg, all_ldisp, all_rdisp, test_limg, test_rimg, test_ldisp, test_rdisp = sf.sf_loader(args.data_path)
43 |
44 | trainLoader = torch.utils.data.DataLoader(
45 | sf.myDataset(all_limg, all_rimg, all_ldisp, all_rdisp, training=True, color_transform=args.color_transform),
46 | batch_size=args.batch_size, shuffle=True, num_workers=4, drop_last=False)
47 |
48 |
49 | fe_model = FE.VGG_Feature(fixed_param=True).eval()
50 | adaptor = un.U_Net_v4(img_ch=256, output_ch=64).eval()
51 | model = Agg.PSMAggregator(args.max_disp, udc=True).train()
52 | print('Number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()])))
53 |
54 | if cuda:
55 | fe_model = nn.DataParallel(fe_model.cuda())
56 | adaptor = nn.DataParallel(adaptor.cuda())
57 | model = nn.DataParallel(model.cuda())
58 |
59 | adaptor.load_state_dict(torch.load(args.load_path)['fa_net'])
60 | for p in adaptor.parameters():
61 | p.requires_grad = False
62 |
63 | optimizer = optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999))
64 |
65 |
66 | def train(imgL, imgR, gt_left, gt_right):
67 | imgL = torch.FloatTensor(imgL)
68 | imgR = torch.FloatTensor(imgR)
69 | gt_left = torch.FloatTensor(gt_left)
70 | gt_right = torch.FloatTensor(gt_right)
71 |
72 | if cuda:
73 | imgL, imgR, gt_left, gt_right = imgL.cuda(), imgR.cuda(), gt_left.cuda(), gt_right.cuda()
74 |
75 | optimizer.zero_grad()
76 |
77 | with torch.no_grad():
78 | left_fea = fe_model(imgL)
79 | right_fea = fe_model(imgR)
80 |
81 | left_fea = adaptor(left_fea)
82 | right_fea = adaptor(right_fea)
83 |
84 | loss1, loss2 = model(left_fea, right_fea, gt_left, training=True)
85 | loss1 = torch.mean(loss1)
86 | loss2 = torch.mean(loss2)
87 | loss = 0.1 * loss1 + loss2
88 | # loss = loss1
89 |
90 | loss.backward()
91 | optimizer.step()
92 |
93 | return loss1.item(), loss2.item()
94 |
95 |
96 | def adjust_learning_rate(optimizer, epoch):
97 | if epoch <= 5:
98 | lr = 0.001
99 | else:
100 | lr = 0.0001
101 | # print(lr)
102 | for param_group in optimizer.param_groups:
103 | param_group['lr'] = lr
104 |
105 |
106 | def main():
107 |
108 | # start_total_time = time.time()
109 | start_epoch = 1
110 |
111 | # checkpoint = torch.load('trained_ft_costAgg/checkpoint_1_v4.tar')
112 |     # model.load_state_dict(checkpoint['net'])
113 | # optimizer.load_state_dict(checkpoint['optimizer'])
114 | # start_epoch = checkpoint['epoch'] + 1
115 |
116 | for epoch in range(start_epoch, args.epoch + start_epoch):
117 |         print('This is the %d-th epoch' % epoch)
118 | total_train_loss1 = 0
119 | total_train_loss2 = 0
120 | adjust_learning_rate(optimizer, epoch)
121 | #
122 |
123 | for batch_id, (imgL, imgR, disp_L, disp_R) in enumerate(tqdm(trainLoader)):
124 | train_loss1, train_loss2 = train(imgL, imgR, disp_L, disp_R)
125 | total_train_loss1 += train_loss1
126 | total_train_loss2 += train_loss2
127 | avg_train_loss1 = total_train_loss1 / len(trainLoader)
128 | avg_train_loss2 = total_train_loss2 / len(trainLoader)
129 | print('Epoch %d average training loss1 = %.3f, average training loss2 = %.3f' %
130 | (epoch, avg_train_loss1, avg_train_loss2))
131 |
132 | state = {'fa_net': adaptor.state_dict(),
133 | 'net': model.state_dict(),
134 | 'optimizer': optimizer.state_dict(),
135 | 'epoch': epoch}
136 | if not os.path.exists(args.save_path):
137 | os.mkdir(args.save_path)
138 | save_model_path = args.save_path + 'checkpoint_final_{}epoch.tar'.format(epoch)
139 | torch.save(state, save_model_path)
140 |
141 | torch.cuda.empty_cache()
142 |
143 |
144 | if __name__ == '__main__':
145 | main()
146 |
147 |
--------------------------------------------------------------------------------
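The checkpoints written above store the frozen adaptor under the key 'fa_net' and the retrained cost aggregator under 'net'. Below is a minimal inference sketch that mirrors how this script wires the three modules together; the checkpoint path and the input tensors are placeholders, a CUDA device is assumed (the cost volume is built on the GPU), and the actual evaluation code lives in test_kitti.py, test_middlebury.py and test_eth3d.py.

    # Minimal sketch, assuming a CUDA device and a checkpoint saved by retrain_CostAggregation.py.
    import torch
    import torch.nn as nn
    import networks.feature_extraction as FE
    import networks.U_net as un
    import networks.Aggregator as Agg

    fe_model = nn.DataParallel(FE.VGG_Feature(fixed_param=True).cuda()).eval()
    adaptor = nn.DataParallel(un.U_Net_v4(img_ch=256, output_ch=64).cuda()).eval()
    agg = nn.DataParallel(Agg.PSMAggregator(192, udc=True).cuda()).eval()

    ckpt = torch.load('trained_models/checkpoint_final_10epoch.tar')  # placeholder path
    adaptor.load_state_dict(ckpt['fa_net'])
    agg.load_state_dict(ckpt['net'])

    # Stand-ins for normalized left/right crops; side lengths should be divisible by 16.
    imgL = torch.rand(1, 3, 256, 512).cuda()
    imgR = torch.rand(1, 3, 256, 512).cuda()

    with torch.no_grad():
        left_fea = adaptor(fe_model(imgL))
        right_fea = adaptor(fe_model(imgR))
        # With the module in eval mode and training=False, forward returns the final disparity map (pred3).
        disp = agg(left_fea, right_fea, None, training=False)
    print(disp.shape)  # (1, 256, 512) if the backbone features are at 1/4 resolution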
/loss_functions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import cv2
6 | from PIL import Image
7 | import matplotlib.pyplot as plt
8 |
9 |
10 | def disp2distribute(disp_gt, max_disp, b=2):
11 | disp_gt = disp_gt.unsqueeze(1)
12 | disp_range = torch.arange(0, max_disp).view(1, -1, 1, 1).float().cuda()
13 | gt_distribute = torch.exp(-torch.abs(disp_range - disp_gt) / b)
14 | gt_distribute = gt_distribute / (torch.sum(gt_distribute, dim=1, keepdim=True) + 1e-8)
15 | return gt_distribute
16 |
17 |
18 | def CEloss(disp_gt, max_disp, gt_distribute, pred_distribute):
19 | mask = (disp_gt > 0) & (disp_gt < max_disp)
20 |
21 | pred_distribute = torch.log(pred_distribute + 1e-8)
22 |
23 | ce_loss = torch.sum(-gt_distribute * pred_distribute, dim=1)
24 | ce_loss = torch.mean(ce_loss[mask])
25 | return ce_loss
26 |
27 |
28 | def gradient_x(img):
29 | img = F.pad(img, (0, 1, 0, 0), mode="replicate")
30 | gx = img[:, :, :, :-1] - img[:, :, :, 1:]
31 | return gx
32 |
33 |
34 | def gradient_y(img):
35 | img = F.pad(img, (0, 0, 0, 1), mode="replicate")
36 | gy = img[:, :, :-1, :] - img[:, :, 1:, :]
37 | return gy
38 |
39 |
40 | def smooth_loss(img, disp):
41 | img_gx = gradient_x(img)
42 | img_gy = gradient_y(img)
43 | disp_gx = gradient_x(disp)
44 | disp_gy = gradient_y(disp)
45 |
46 | weight_x = torch.exp(-torch.mean(torch.abs(img_gx), dim=1, keepdim=True))
47 | weight_y = torch.exp(-torch.mean(torch.abs(img_gy), dim=1, keepdim=True))
48 | smoothness_x = torch.abs(disp_gx * weight_x)
49 | smoothness_y = torch.abs(disp_gy * weight_y)
50 | smoothness_loss = smoothness_x + smoothness_y
51 |
52 | return torch.mean(smoothness_loss)
53 |
54 |
55 |
56 | def occlusion_mask(left_disp, right_disp, threshold=1):
57 | # left_disp = left_disp.unsqueeze(1)
58 | # right_disp = right_disp.unsqueeze(1)
59 |
60 | B, _, H, W = left_disp.size()
61 |
62 | x_base = torch.linspace(0, 1, W).repeat(B, H, 1).type_as(right_disp)
63 | y_base = torch.linspace(0, 1, H).repeat(B, W, 1).transpose(1, 2).type_as(right_disp)
64 |
65 | flow_field = torch.stack((x_base - left_disp.squeeze(1) / W, y_base), dim=3)
66 |
67 | recon_left_disp = F.grid_sample(right_disp, 2 * flow_field - 1, mode='bilinear', padding_mode='zeros')
68 |
69 | lr_check = torch.abs(recon_left_disp - left_disp)
70 | mask = lr_check > threshold
71 |
72 | return mask
73 |
74 |
75 | def reconstruction(right, disp):
76 | b, _, h, w = right.size()
77 |
78 | x_base = torch.linspace(0, 1, w).repeat(b, h, 1).type_as(right)
79 | y_base = torch.linspace(0, 1, h).repeat(b, w, 1).transpose(1, 2).type_as(right)
80 |
81 | flow_field = torch.stack((x_base - disp / w, y_base), dim=3)
82 |
83 | recon_left = F.grid_sample(right, 2 * flow_field - 1, mode='bilinear', padding_mode='zeros')
84 | return recon_left
85 |
86 |
87 | def NT_Xent_loss(positive_simi, negative_simi, t):
88 | loss = torch.exp(positive_simi / t) / \
89 | (torch.exp(positive_simi / t) + torch.sum(torch.exp(negative_simi / t), dim=4))
90 | loss = -torch.log(loss + 1e-9)
91 | return loss
92 |
93 |
94 | class FeatureSimilarityLoss(nn.Module):
95 | def __init__(self, max_disp):
96 | super(FeatureSimilarityLoss, self).__init__()
97 | self.max_disp = max_disp
98 | self.m = 0.3
99 | self.nega_num = 1
100 |
101 | def forward(self, left_fea, right_fea, left_disp, right_disp):
102 | B, _, H, W = left_fea.size()
103 |
104 | down_disp = F.interpolate(left_disp, (H, W), mode='nearest') / 4.
105 | # down_img = F.interpolate(left_img, (H, W), mode='nearest')
106 | # down_img = torch.mean(down_img, dim=1, keepdim=True)
107 |
108 | # t_map = self.t_net(left_fea)
109 |
110 | # create negative samples
111 | random_offset = torch.rand(B, self.nega_num, H, W).cuda() * 2 + 1
112 | random_sign = torch.sign(torch.rand(B, self.nega_num, H, W).cuda() - 0.5)
113 | random_offset *= random_sign
114 | negative_disp = down_disp + random_offset
115 |
116 | positive_recon = reconstruction(right_fea, down_disp.squeeze(1))
117 | negative_recon = []
118 | for i in range(self.nega_num):
119 | negative_recon.append(reconstruction(right_fea, negative_disp[:, i]))
120 | negative_recon = torch.stack(negative_recon, dim=4)
121 |
122 | left_fea = F.normalize(left_fea, dim=1)
123 | positive_recon = F.normalize(positive_recon, dim=1)
124 | negative_recon = F.normalize(negative_recon, dim=1)
125 |
126 | positive_simi = (torch.sum(left_fea * positive_recon, dim=1, keepdim=True) + 1) / 2
127 | negative_simi = (torch.sum(left_fea.unsqueeze(4) * negative_recon, dim=1, keepdim=True) + 1) / 2
128 |
129 | judge_mat_p = torch.zeros_like(positive_simi)
130 | judge_mat_n = torch.zeros_like(negative_simi)
131 | if torch.sum(positive_simi < judge_mat_p) > 0 or torch.sum(negative_simi < judge_mat_n) > 0:
132 | print('cosine_simi < 0')
133 |
134 | # hinge loss
135 | # dist = self.m + negative_simi - positive_simi
136 | # criteria = torch.zeros_like(dist)
137 | # loss, _ = torch.max(torch.cat((dist, criteria), dim=1), dim=1, keepdim=True)
138 |
139 | # NT-Xent loss
140 | # loss = NT_Xent_loss(positive_simi, negative_simi, t=t_map)
141 | loss = NT_Xent_loss(positive_simi, negative_simi, t=0.2)
142 |
143 | # img_grad = torch.sqrt(gradient_x(down_img) ** 2 + gradient_y(down_img) ** 2)
144 | # weight = torch.exp(-img_grad)
145 | # loss = loss * weight
146 |
147 | occ_mask = occlusion_mask(left_disp, right_disp, threshold=1)
148 | occ_mask = F.interpolate(occ_mask.float(), (H, W), mode='nearest')
149 | valid_mask = (down_disp > 0) & (down_disp < self.max_disp // 4) & (occ_mask == 0)
150 |
151 | return torch.mean(loss[valid_mask])
152 |
153 |
154 | def gram_matrix(feature):
155 | B, C, H, W = feature.size()
156 | feature = feature.view(B, C, H * W)
157 | feature_t = feature.transpose(1, 2)
158 | gram_m = torch.bmm(feature, feature_t) / (H * W)
159 | return gram_m
160 |
161 |
162 | def gram_matrix_v2(feature):
163 | B, C, H, W = feature.size()
164 | feature = feature.view(B * C, H * W)
165 | gram_m = torch.mm(feature, feature.t()) / (B * C * H * W)
166 | return gram_m
167 |
168 |
169 | if __name__ == '__main__':
170 |
171 | a = torch.rand(2, 256, 64, 128)
172 | b = torch.rand(2, 256, 64, 128)
173 |
174 | gram_a = gram_matrix(a)
175 | gram_b = gram_matrix(b)
176 | print(F.mse_loss(gram_a, gram_b))
177 |
178 | ga_2 = gram_matrix_v2(a)
179 | gb_2 = gram_matrix_v2(b)
180 | print(F.mse_loss(ga_2, gb_2))
181 |
--------------------------------------------------------------------------------
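disp2distribute converts a ground-truth disparity map into a Laplace-shaped soft label over the disparity bins, and CEloss is a masked cross-entropy between that soft label and a predicted softmax distribution; together they form the loss2 term used by the aggregators. A toy sketch, assuming a CUDA device since disp2distribute places the disparity range on the GPU:

    import torch
    import torch.nn.functional as F
    import loss_functions as lf

    max_disp = 192
    disp_gt = torch.rand(2, 64, 128).cuda() * max_disp   # (B, H, W) ground-truth disparity
    cost = torch.rand(2, max_disp, 64, 128).cuda()        # toy cost volume logits
    pred_distribute = F.softmax(cost, dim=1)              # (B, D, H, W) predicted distribution

    gt_distribute = lf.disp2distribute(disp_gt, max_disp, b=2)
    loss = lf.CEloss(disp_gt, max_disp, gt_distribute, pred_distribute)
    print(loss.item())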
/networks/vgg.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | __all__ = [
6 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
7 | 'vgg19_bn', 'vgg19',
8 | ]
9 |
10 |
11 | class VGG(nn.Module):
12 |
13 | def __init__(self, features, num_classes=1000, init_weights=True):
14 | super(VGG, self).__init__()
15 | self.features = features
16 | self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
17 | self.classifier = nn.Sequential(
18 | nn.Linear(512 * 7 * 7, 4096),
19 | nn.ReLU(True),
20 | nn.Dropout(),
21 | nn.Linear(4096, 4096),
22 | nn.ReLU(True),
23 | nn.Dropout(),
24 | nn.Linear(4096, num_classes),
25 | )
26 | if init_weights:
27 | self._initialize_weights()
28 |
29 | def forward(self, x):
30 | x = self.features(x)
31 | x = self.avgpool(x)
32 | x = torch.flatten(x, 1)
33 | x = self.classifier(x)
34 | return x
35 |
36 | def _initialize_weights(self):
37 | for m in self.modules():
38 | if isinstance(m, nn.Conv2d):
39 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
40 | if m.bias is not None:
41 | nn.init.constant_(m.bias, 0)
42 | elif isinstance(m, nn.BatchNorm2d):
43 | nn.init.constant_(m.weight, 1)
44 | nn.init.constant_(m.bias, 0)
45 | elif isinstance(m, nn.Linear):
46 | nn.init.normal_(m.weight, 0, 0.01)
47 | nn.init.constant_(m.bias, 0)
48 |
49 |
50 | def make_layers(cfg, batch_norm=False):
51 | layers = []
52 | in_channels = 3
53 | for v in cfg:
54 | if v == 'M':
55 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
56 | else:
57 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
58 | if batch_norm:
59 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
60 | else:
61 | layers += [conv2d, nn.ReLU(inplace=True)]
62 | in_channels = v
63 | return nn.Sequential(*layers)
64 |
65 |
66 | cfgs = {
67 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
68 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
69 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
70 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
71 | }
72 |
73 |
74 | def _vgg(arch, cfg, batch_norm, pretrained, progress, **kwargs):
75 | if pretrained:
76 | kwargs['init_weights'] = False
77 | model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
78 | return model
79 |
80 |
81 | def vgg11(pretrained=False, progress=True, **kwargs):
82 | r"""VGG 11-layer model (configuration "A") from
83 |     `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/abs/1409.1556>`_
84 |
85 | Args:
86 | pretrained (bool): If True, returns a model pre-trained on ImageNet
87 | progress (bool): If True, displays a progress bar of the download to stderr
88 | """
89 | return _vgg('vgg11', 'A', False, pretrained, progress, **kwargs)
90 |
91 |
92 | def vgg11_bn(pretrained=False, progress=True, **kwargs):
93 | r"""VGG 11-layer model (configuration "A") with batch normalization
94 |     `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/abs/1409.1556>`_
95 |
96 | Args:
97 | pretrained (bool): If True, returns a model pre-trained on ImageNet
98 | progress (bool): If True, displays a progress bar of the download to stderr
99 | """
100 | return _vgg('vgg11_bn', 'A', True, pretrained, progress, **kwargs)
101 |
102 |
103 | def vgg13(pretrained=False, progress=True, **kwargs):
104 | r"""VGG 13-layer model (configuration "B")
105 |     `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/abs/1409.1556>`_
106 |
107 | Args:
108 | pretrained (bool): If True, returns a model pre-trained on ImageNet
109 | progress (bool): If True, displays a progress bar of the download to stderr
110 | """
111 | return _vgg('vgg13', 'B', False, pretrained, progress, **kwargs)
112 |
113 |
114 | def vgg13_bn(pretrained=False, progress=True, **kwargs):
115 | r"""VGG 13-layer model (configuration "B") with batch normalization
116 |     `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/abs/1409.1556>`_
117 |
118 | Args:
119 | pretrained (bool): If True, returns a model pre-trained on ImageNet
120 | progress (bool): If True, displays a progress bar of the download to stderr
121 | """
122 | return _vgg('vgg13_bn', 'B', True, pretrained, progress, **kwargs)
123 |
124 |
125 | def vgg16(pretrained=False, progress=True, **kwargs):
126 | r"""VGG 16-layer model (configuration "D")
127 |     `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/abs/1409.1556>`_
128 |
129 | Args:
130 | pretrained (bool): If True, returns a model pre-trained on ImageNet
131 | progress (bool): If True, displays a progress bar of the download to stderr
132 | """
133 | return _vgg('vgg16', 'D', False, pretrained, progress, **kwargs)
134 |
135 |
136 | def vgg16_bn(pretrained=False, progress=True, **kwargs):
137 | r"""VGG 16-layer model (configuration "D") with batch normalization
138 |     `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/abs/1409.1556>`_
139 |
140 | Args:
141 | pretrained (bool): If True, returns a model pre-trained on ImageNet
142 | progress (bool): If True, displays a progress bar of the download to stderr
143 | """
144 | return _vgg('vgg16_bn', 'D', True, pretrained, progress, **kwargs)
145 |
146 |
147 | def vgg19(pretrained=False, progress=True, **kwargs):
148 | r"""VGG 19-layer model (configuration "E")
149 |     `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/abs/1409.1556>`_
150 |
151 | Args:
152 | pretrained (bool): If True, returns a model pre-trained on ImageNet
153 | progress (bool): If True, displays a progress bar of the download to stderr
154 | """
155 | return _vgg('vgg19', 'E', False, pretrained, progress, **kwargs)
156 |
157 |
158 | def vgg19_bn(pretrained=False, progress=True, **kwargs):
159 | r"""VGG 19-layer model (configuration 'E') with batch normalization
160 |     `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/abs/1409.1556>`_
161 |
162 | Args:
163 | pretrained (bool): If True, returns a model pre-trained on ImageNet
164 | progress (bool): If True, displays a progress bar of the download to stderr
165 | """
166 | return _vgg('vgg19_bn', 'E', True, pretrained, progress, **kwargs)
167 |
--------------------------------------------------------------------------------
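This is the torchvision VGG definition with the weight-download logic removed, so the pretrained flag only skips random initialization; ImageNet weights would have to be loaded from a local state_dict. One plausible, purely illustrative way to obtain a 256-channel quarter-resolution feature map (the channel count the adaptor expects) is to slice the features stack up to the conv3_3 ReLU; the actual cut used in feature_extraction.py may differ.

    import torch
    import torch.nn as nn
    from networks.vgg import vgg16_bn

    backbone = vgg16_bn(pretrained=False)
    # Hypothetical cut point: layers 0..22 of vgg16_bn end at the conv3_3 ReLU,
    # giving 256 channels at 1/4 resolution (two max-pools applied).
    shallow = nn.Sequential(*list(backbone.features.children())[:23])

    x = torch.rand(1, 3, 256, 512)
    print(shallow(x).shape)  # torch.Size([1, 256, 64, 128])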
/networks/submodule.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import torch
3 | import torch.nn as nn
4 | import torch.utils.data
5 | from torch.autograd import Variable
6 | import torch.nn.functional as F
7 | from torchvision import models
8 | import math
9 | import numpy as np
10 | import torchvision.transforms as transforms
11 | import PIL
12 | import os
13 | import matplotlib.pyplot as plt
14 | from networks.resnet import ResNet, Bottleneck, BasicBlock_Res
15 |
16 |
17 | def convbn(in_planes, out_planes, kernel_size, stride, pad, dilation):
18 |
19 | return nn.Sequential(nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=dilation if dilation > 1 else pad, dilation = dilation, bias=False),
20 | nn.BatchNorm2d(out_planes))
21 |
22 |
23 | def convbn_3d(in_planes, out_planes, kernel_size, stride, pad):
24 |
25 | return nn.Sequential(nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, padding=pad, stride=stride,bias=False),
26 | nn.BatchNorm3d(out_planes))
27 |
28 |
29 | class BasicBlock(nn.Module):
30 | expansion = 1
31 | def __init__(self, inplanes, planes, stride, downsample, pad, dilation):
32 | super(BasicBlock, self).__init__()
33 |
34 | self.conv1 = nn.Sequential(convbn(inplanes, planes, 3, stride, pad, dilation),
35 | nn.ReLU(inplace=True))
36 |
37 | self.conv2 = convbn(planes, planes, 3, 1, pad, dilation)
38 |
39 | self.downsample = downsample
40 | self.stride = stride
41 |
42 | def forward(self, x):
43 | out = self.conv1(x)
44 | out = self.conv2(out)
45 |
46 | if self.downsample is not None:
47 | x = self.downsample(x)
48 |
49 | out += x
50 |
51 | return out
52 |
53 |
54 | class DisparityRegression(nn.Module):
55 |
56 | def __init__(self, maxdisp, win_size):
57 | super(DisparityRegression, self).__init__()
58 | self.max_disp = maxdisp
59 | self.win_size = win_size
60 |
61 | def forward(self, x):
62 | disp = torch.arange(0, self.max_disp).view(1, -1, 1, 1).float().to(x.device)
63 |
64 | if self.win_size > 0:
65 | max_d = torch.argmax(x, dim=1, keepdim=True)
66 | d_value = []
67 | prob_value = []
68 | for d in range(-self.win_size, self.win_size + 1):
69 | index = max_d + d
70 | index[index < 0] = 0
71 | index[index > x.shape[1] - 1] = x.shape[1] - 1
72 | d_value.append(index)
73 |
74 | prob = torch.gather(x, dim=1, index=index)
75 | prob_value.append(prob)
76 |
77 | part_x = torch.cat(prob_value, dim=1)
78 | part_x = part_x / (torch.sum(part_x, dim=1, keepdim=True) + 1e-8)
79 | part_d = torch.cat(d_value, dim=1).float()
80 | out = torch.sum(part_x * part_d, dim=1)
81 |
82 | else:
83 | out = torch.sum(x * disp, 1)
84 |
85 | return out
86 |
87 |
88 | class GwcFeature(nn.Module):
89 | def __init__(self, out_c, fuse_mode='add'):
90 | super(GwcFeature, self).__init__()
91 | self.inplanes = 32
92 | self.fuse_mode = fuse_mode
93 |
94 | self.firstconv = nn.Sequential(convbn(3, 32, 3, 2, 1, 1),
95 | nn.ReLU(inplace=True),
96 | convbn(32, 32, 3, 1, 1, 1),
97 | nn.ReLU(inplace=True),
98 | convbn(32, 32, 3, 1, 1, 1),
99 | nn.ReLU(inplace=True))
100 |
101 | self.layer1 = self._make_layer(BasicBlock, 32, 3, 1, 1, 1)
102 | self.layer2 = self._make_layer(BasicBlock, 64, 16, 2, 1, 1)
103 | self.layer3 = self._make_layer(BasicBlock, 128, 3, 1, 1, 1)
104 | self.layer4 = self._make_layer(BasicBlock, 128, 3, 1, 1, 2)
105 |
106 | if self.fuse_mode == 'cat':
107 | self.lastconv = nn.Sequential(convbn(320, 128, 3, 1, 1, 1),
108 | nn.ReLU(inplace=True),
109 | nn.Conv2d(128, out_c, kernel_size=1, padding=0, stride=1, bias=False))
110 | elif self.fuse_mode == 'add':
111 | self.l1_conv = nn.Conv2d(32, out_c, 1, stride=1, padding=0, bias=False)
112 | self.l2_conv = nn.Conv2d(64, out_c, 1, stride=1, padding=0, bias=False)
113 | self.l4_conv = nn.Conv2d(128, out_c, 1, stride=1, padding=0, bias=False)
114 | elif self.fuse_mode == 'add_sa':
115 | self.l1_conv = nn.Conv2d(64, out_c, 1, stride=1, padding=0, bias=False)
116 | self.l4_conv = nn.Conv2d(64, out_c, 1, stride=1, padding=0, bias=False)
117 | self.sa = nn.Sequential(convbn(2 * out_c, 2 * out_c, 3, 1, 1, 1),
118 | nn.ReLU(inplace=True),
119 | nn.Conv2d(2 * out_c, 2, 3, stride=1, padding=1, bias=False))
120 |
121 | def _make_layer(self, block, planes, blocks, stride, pad, dilation):
122 | downsample = None
123 | if stride != 1 or self.inplanes != planes * block.expansion:
124 | downsample = nn.Sequential(
125 | nn.Conv2d(self.inplanes, planes * block.expansion,
126 | kernel_size=1, stride=stride, bias=False),
127 | nn.BatchNorm2d(planes * block.expansion),)
128 |
129 | layers = []
130 | layers.append(block(self.inplanes, planes, stride, downsample, pad, dilation))
131 | self.inplanes = planes * block.expansion
132 | for i in range(1, blocks):
133 | layers.append(block(self.inplanes, planes,1,None,pad,dilation))
134 |
135 | return nn.Sequential(*layers)
136 |
137 | def forward(self, x):
138 | output = self.firstconv(x)
139 | output_l1 = self.layer1(output)
140 | output_l2 = self.layer2(output_l1)
141 | output_l3 = self.layer3(output_l2)
142 | output_l4 = self.layer4(output_l3)
143 |
144 | output_l1 = F.interpolate(output_l1, (output_l4.size()[2], output_l4.size()[3]),
145 | mode='bilinear', align_corners=True)
146 |
147 | if self.fuse_mode == 'cat':
148 | cat_feature = torch.cat((output_l2, output_l3, output_l4), dim=1)
149 | output_feature = self.lastconv(cat_feature)
150 | elif self.fuse_mode == 'add':
151 | output_l1 = self.l1_conv(output_l1)
152 | output_l4 = self.l4_conv(output_l4)
153 | output_feature = output_l1 + output_l4
154 | elif self.fuse_mode == 'add_sa':
155 | output_l1 = self.l1_conv(output_l1)
156 | output_l4 = self.l4_conv(output_l4)
157 |
158 | attention_map = self.sa(torch.cat((output_l1, output_l4), dim=1))
159 | attention_map = torch.sigmoid(attention_map)
160 | output_feature = output_l1 * attention_map[:, 0].unsqueeze(1) + \
161 | output_l4 * attention_map[:, 1].unsqueeze(1)
162 |
163 | return output_feature
164 |
165 |
166 |
--------------------------------------------------------------------------------
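DisparityRegression turns a softmax probability volume into a disparity map: with win_size=0 it is the usual soft-argmax over all bins, while win_size > 0 renormalizes the probabilities in a small window around the per-pixel argmax and regresses only there (the windowed variant selected when udc=True). A quick sketch:

    import torch
    import torch.nn.functional as F
    from networks.submodule import DisparityRegression

    max_disp = 48
    prob = F.softmax(torch.rand(2, max_disp, 16, 32), dim=1)  # (B, D, H, W) probability volume

    full = DisparityRegression(max_disp, win_size=0)(prob)    # soft-argmax over all bins
    local = DisparityRegression(max_disp, win_size=5)(prob)   # +/-5 bins around the argmax
    print(full.shape, local.shape)                            # both (2, 16, 32)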
/dataloader/sceneflow_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | from dataloader import readpfm as rp
4 |
5 | import torch.utils.data as data
6 | import torchvision.transforms as transforms
7 | import numpy as np
8 | import random
9 |
10 | IMG_EXTENSIONS= [
11 | '.jpg', '.JPG', '.jpeg', '.JPEG',
12 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP'
13 | ]
14 |
15 |
16 | def is_image_file(filename):
17 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
18 |
19 |
20 | # filepath = '/media/data/LiuBiyang/SceneFlow/'
21 | def sf_loader(filepath):
22 |
23 | classes = [d for d in os.listdir(filepath) if os.path.isdir(os.path.join(filepath, d))]
24 | image = [img for img in classes if img.find('frames_cleanpass') > -1]
25 | disparity = [disp for disp in classes if disp.find('disparity') > -1]
26 |
27 | all_left_img = []
28 | all_right_img = []
29 | all_left_disp = []
30 | all_right_disp = []
31 | test_left_img = []
32 | test_right_img = []
33 | test_left_disp = []
34 | test_right_disp = []
35 |
36 | monkaa_img = filepath + [x for x in image if 'monkaa' in x][0]
37 | monkaa_disp = filepath + [x for x in disparity if 'monkaa' in x][0]
38 | monkaa_dir = os.listdir(monkaa_img)
39 | for dd in monkaa_dir:
40 | left_path = monkaa_img + '/' + dd + '/left/'
41 | right_path = monkaa_img + '/' + dd + '/right/'
42 | disp_path = monkaa_disp + '/' + dd + '/left/'
43 | rdisp_path = monkaa_disp + '/' + dd + '/right/'
44 |
45 | left_imgs = os.listdir(left_path)
46 | for img in left_imgs:
47 | img_path = os.path.join(left_path, img)
48 | if is_image_file(img_path):
49 | all_left_img.append(img_path)
50 | all_right_img.append(os.path.join(right_path, img))
51 | all_left_disp.append(disp_path + img.split(".")[0] + '.pfm')
52 | all_right_disp.append(rdisp_path + img.split(".")[0] + '.pfm')
53 |
54 | flying_img = filepath + [x for x in image if 'flying' in x][0]
55 | flying_disp = filepath + [x for x in disparity if 'flying' in x][0]
56 | fimg_train = flying_img + '/TRAIN/'
57 | fimg_test = flying_img + '/TEST/'
58 | fdisp_train = flying_disp + '/TRAIN/'
59 | fdisp_test = flying_disp + '/TEST/'
60 | fsubdir = ['A', 'B', 'C']
61 |
62 | for dd in fsubdir:
63 | imgs_path = fimg_train + dd + '/'
64 | disps_path = fdisp_train + dd + '/'
65 | imgs = os.listdir(imgs_path)
66 | for cc in imgs:
67 | left_path = imgs_path + cc + '/left/'
68 | right_path = imgs_path + cc + '/right/'
69 | disp_path = disps_path + cc + '/left/'
70 | rdisp_path = disps_path + cc + '/right/'
71 |
72 | left_imgs = os.listdir(left_path)
73 | for img in left_imgs:
74 | img_path = os.path.join(left_path, img)
75 | if is_image_file(img_path):
76 | all_left_img.append(img_path)
77 | all_right_img.append(os.path.join(right_path, img))
78 | all_left_disp.append(disp_path + img.split(".")[0] + '.pfm')
79 | all_right_disp.append(rdisp_path + img.split(".")[0] + '.pfm')
80 |
81 | for dd in fsubdir:
82 | imgs_path = fimg_test + dd + '/'
83 | disps_path = fdisp_test + dd + '/'
84 | imgs = os.listdir(imgs_path)
85 | for cc in imgs:
86 | left_path = imgs_path + cc + '/left/'
87 | right_path = imgs_path + cc + '/right/'
88 | disp_path = disps_path + cc + '/left/'
89 | rdisp_path = disps_path + cc + '/right/'
90 |
91 | left_imgs = os.listdir(left_path)
92 | for img in left_imgs:
93 | img_path = os.path.join(left_path, img)
94 | if is_image_file(img_path):
95 | test_left_img.append(img_path)
96 | test_right_img.append(os.path.join(right_path, img))
97 | test_left_disp.append(disp_path + img.split(".")[0] + '.pfm')
98 | test_right_disp.append(rdisp_path + img.split(".")[0] + '.pfm')
99 |
100 | driving_img = filepath + [x for x in image if 'driving' in x][0]
101 | driving_disp = filepath + [x for x in disparity if 'driving' in x][0]
102 | dsubdir1 = ['15mm_focallength', '35mm_focallength']
103 | dsubdir2 = ['scene_backwards', 'scene_forwards']
104 | dsubdir3 = ['fast', 'slow']
105 | for d in dsubdir1:
106 | img_path1 = driving_img + '/' + d + '/'
107 | disp_path1 = driving_disp + '/' + d + '/'
108 | for dd in dsubdir2:
109 | img_path2 = img_path1 + dd + '/'
110 | disp_path2 = disp_path1 + dd + '/'
111 | for ddd in dsubdir3:
112 | img_path3 = img_path2 + ddd + '/'
113 | disp_path3 = disp_path2 + ddd + '/'
114 |
115 | left_path = img_path3 + 'left/'
116 | right_path = img_path3 + 'right/'
117 | disp_path = disp_path3 + 'left/'
118 | rdisp_path = disp_path3 + 'right/'
119 |
120 | left_imgs = os.listdir(left_path)
121 | for img in left_imgs:
122 | img_path = os.path.join(left_path, img)
123 | if is_image_file(img_path):
124 | all_left_img.append(img_path)
125 | all_right_img.append(os.path.join(right_path, img))
126 | all_left_disp.append(disp_path + img.split(".")[0] + '.pfm')
127 | all_right_disp.append(rdisp_path + img.split(".")[0] + '.pfm')
128 |
129 | return all_left_img, all_right_img, all_left_disp, all_right_disp, \
130 | test_left_img, test_right_img, test_left_disp, test_right_disp
131 |
132 |
133 | def img_loader(path):
134 | return Image.open(path).convert('RGB')
135 |
136 |
137 | def disparity_loader(path):
138 | return rp.readPFM(path)
139 |
140 |
141 | def random_transform(left_img, right_img):
142 | if np.random.rand(1) <= 0.2:
143 | left_img = transforms.Grayscale(num_output_channels=3)(left_img)
144 | right_img = transforms.Grayscale(num_output_channels=3)(right_img)
145 | else:
146 | left_img = transforms.ColorJitter(brightness=0.5, contrast=0.5, hue=0.1)(left_img)
147 | right_img = transforms.ColorJitter(brightness=0.5, contrast=0.5, hue=0.1)(right_img)
148 | return left_img, right_img
149 |
150 |
151 | class myDataset(data.Dataset):
152 |
153 | def __init__(self, left, right, left_disp, right_disp, training, imgloader=img_loader, dploader = disparity_loader,
154 | color_transform = False):
155 | self.left = left
156 | self.right = right
157 | self.disp_L = left_disp
158 | self.disp_R = right_disp
159 | self.imgloader = imgloader
160 | self.dploader = dploader
161 | self.training = training
162 |         self.img_transform = transforms.Compose([
163 | transforms.ToTensor(),
164 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
165 | self.color_transform = color_transform
166 |
167 | def __getitem__(self, index):
168 | left = self.left[index]
169 | right = self.right[index]
170 | disp_L = self.disp_L[index]
171 | disp_R = self.disp_R[index]
172 |
173 | left_img = self.imgloader(left)
174 | right_img = self.imgloader(right)
175 | dataL, _ = self.dploader(disp_L)
176 | dataL = np.ascontiguousarray(dataL, dtype=np.float32)
177 | dataR, _ = self.dploader(disp_R)
178 | dataR = np.ascontiguousarray(dataR, dtype=np.float32)
179 |
180 | if self.training:
181 | w, h = left_img.size
182 | tw, th = 512, 256
183 | x1 = random.randint(0, w - tw)
184 | y1 = random.randint(0, h - th)
185 |
186 | left_img = left_img.crop((x1, y1, x1+tw, y1+th))
187 | right_img = right_img.crop((x1, y1, x1+tw, y1+th))
188 | dataL = dataL[y1:y1+th, x1:x1+tw]
189 | dataR = dataR[y1:y1+th, x1:x1+tw]
190 |
191 | if self.color_transform:
192 | left_img, right_img = random_transform(left_img, right_img)
193 |
194 |             left_img = self.img_transform(left_img)
195 |             right_img = self.img_transform(right_img)
196 |
197 | return left_img, right_img, dataL, dataR
198 |
199 | else:
200 | w, h = left_img.size
201 | left_img = left_img.crop((w-960, h-544, w, h))
202 | right_img = right_img.crop((w-960, h-544, w, h))
203 |
204 |             left_img = self.img_transform(left_img)
205 |             right_img = self.img_transform(right_img)
206 |
207 | dataL = Image.fromarray(dataL).crop((w-960, h-544, w, h))
208 | dataL = np.ascontiguousarray(dataL)
209 | dataR = Image.fromarray(dataR).crop((w-960, h-544, w, h))
210 | dataR = np.ascontiguousarray(dataR)
211 |
212 | return left_img, right_img, dataL, dataR
213 |
214 | def __len__(self):
215 | return len(self.left)
216 |
217 |
218 |
219 |
220 |
221 |
--------------------------------------------------------------------------------
/networks/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
7 | """3x3 convolution with padding"""
8 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
9 | padding=dilation, groups=groups, bias=False, dilation=dilation)
10 |
11 |
12 | def conv1x1(in_planes, out_planes, stride=1):
13 | """1x1 convolution"""
14 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
15 |
16 |
17 | class BasicBlock_Res(nn.Module):
18 | expansion = 1
19 |
20 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
21 | base_width=64, dilation=1, norm_layer=None, use_relu=True):
22 | super(BasicBlock_Res, self).__init__()
23 | if norm_layer is None:
24 | norm_layer = nn.BatchNorm2d
25 | if groups != 1 or base_width != 64:
26 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
27 | if dilation > 1:
28 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
29 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
30 | self.conv1 = conv3x3(inplanes, planes, stride)
31 | self.bn1 = norm_layer(planes)
32 | self.relu = nn.ReLU(inplace=True)
33 | self.conv2 = conv3x3(planes, planes)
34 | self.bn2 = norm_layer(planes)
35 | self.downsample = downsample
36 | self.stride = stride
37 |
38 | self.use_relu = use_relu
39 |
40 | def forward(self, x):
41 | identity = x
42 |
43 | out = self.conv1(x)
44 | out = self.bn1(out)
45 | out = self.relu(out)
46 |
47 | out = self.conv2(out)
48 | out = self.bn2(out)
49 |
50 | if self.downsample is not None:
51 | identity = self.downsample(x)
52 |
53 | out += identity
54 |
55 | if self.use_relu:
56 | out = self.relu(out)
57 | return out
58 |
59 |
60 | class Bottleneck(nn.Module):
61 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
62 | # while original implementation places the stride at the first 1x1 convolution(self.conv1)
63 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
64 | # This variant is also known as ResNet V1.5 and improves accuracy according to
65 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
66 |
67 | expansion = 4
68 |
69 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
70 | base_width=64, dilation=1, norm_layer=None, use_relu=True):
71 | super(Bottleneck, self).__init__()
72 | if norm_layer is None:
73 | norm_layer = nn.BatchNorm2d
74 | width = int(planes * (base_width / 64.)) * groups
75 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
76 | self.conv1 = conv1x1(inplanes, width)
77 | self.bn1 = norm_layer(width)
78 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
79 | self.bn2 = norm_layer(width)
80 | self.conv3 = conv1x1(width, planes * self.expansion)
81 | self.bn3 = norm_layer(planes * self.expansion)
82 | self.relu = nn.ReLU(inplace=True)
83 | self.downsample = downsample
84 | self.stride = stride
85 |
86 | self.use_relu = use_relu
87 |
88 | def forward(self, x):
89 | identity = x
90 |
91 | out = self.conv1(x)
92 | out = self.bn1(out)
93 | out = self.relu(out)
94 |
95 | out = self.conv2(out)
96 | out = self.bn2(out)
97 | out = self.relu(out)
98 |
99 | out = self.conv3(out)
100 | out = self.bn3(out)
101 |
102 | if self.downsample is not None:
103 | identity = self.downsample(x)
104 |
105 | out += identity
106 |
107 | if self.use_relu:
108 | out = self.relu(out)
109 |
110 | return out
111 |
112 |
113 | class ResNet(nn.Module):
114 |
115 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
116 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
117 | norm_layer=None):
118 | super(ResNet, self).__init__()
119 | if norm_layer is None:
120 | norm_layer = nn.BatchNorm2d
121 | self._norm_layer = norm_layer
122 |
123 | self.inplanes = 64
124 | self.dilation = 1
125 | if replace_stride_with_dilation is None:
126 | # each element in the tuple indicates if we should replace
127 | # the 2x2 stride with a dilated convolution instead
128 | replace_stride_with_dilation = [False, False, False]
129 | if len(replace_stride_with_dilation) != 3:
130 | raise ValueError("replace_stride_with_dilation should be None "
131 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
132 | self.groups = groups
133 | self.base_width = width_per_group
134 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
135 | bias=False)
136 | self.bn1 = norm_layer(self.inplanes)
137 | self.relu = nn.ReLU(inplace=True)
138 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
139 | self.layer1 = self._make_layer(block, 64, layers[0])
140 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
141 | dilate=replace_stride_with_dilation[0])
142 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
143 | dilate=replace_stride_with_dilation[1])
144 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
145 | dilate=replace_stride_with_dilation[2])
146 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
147 |
148 | # if DenseCl, comment this line
149 | self.fc = nn.Linear(512 * block.expansion, num_classes)
150 |
151 | for m in self.modules():
152 | if isinstance(m, nn.Conv2d):
153 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
154 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
155 | nn.init.constant_(m.weight, 1)
156 | nn.init.constant_(m.bias, 0)
157 |
158 | # Zero-initialize the last BN in each residual branch,
159 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
160 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
161 | if zero_init_residual:
162 | for m in self.modules():
163 | if isinstance(m, Bottleneck):
164 | nn.init.constant_(m.bn3.weight, 0)
165 | elif isinstance(m, BasicBlock_Res):
166 | nn.init.constant_(m.bn2.weight, 0)
167 |
168 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
169 | norm_layer = self._norm_layer
170 | downsample = None
171 | previous_dilation = self.dilation
172 | if dilate:
173 | self.dilation *= stride
174 | stride = 1
175 |
176 | # stride = 1
177 |
178 | if stride != 1 or self.inplanes != planes * block.expansion:
179 | downsample = nn.Sequential(
180 | conv1x1(self.inplanes, planes * block.expansion, stride),
181 | norm_layer(planes * block.expansion),
182 | )
183 |
184 | layers = []
185 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
186 | self.base_width, previous_dilation, norm_layer))
187 | self.inplanes = planes * block.expansion
188 | for i in range(1, blocks):
189 | if i == blocks - 1:
190 | layers.append(block(self.inplanes, planes, groups=self.groups,
191 | base_width=self.base_width, dilation=self.dilation,
192 | norm_layer=norm_layer, use_relu=False))
193 | else:
194 | layers.append(block(self.inplanes, planes, groups=self.groups,
195 | base_width=self.base_width, dilation=self.dilation,
196 | norm_layer=norm_layer, use_relu=True))
197 |
198 | return nn.Sequential(*layers)
199 |
200 | def _forward_impl(self, x):
201 | # See note [TorchScript super()]
202 | x = self.conv1(x)
203 | x = self.bn1(x)
204 | x = self.relu(x)
205 | x = self.maxpool(x)
206 | x = self.layer1(x)
207 |
208 | # x = self.relu(x)
209 |
210 | # x = self.layer2(x)
211 |
212 | # x = self.relu(x)
213 |
214 | # x = self.layer3(x)
215 | # x = self.layer4(x)
216 | #
217 | # x = self.avgpool(x)
218 | # x = torch.flatten(x, 1)
219 |
220 | # x = self.fc(x)
221 |
222 | return x
223 |
224 | def forward(self, x):
225 | return self._forward_impl(x)
226 |
227 |
--------------------------------------------------------------------------------
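_forward_impl above is truncated after layer1 (the deeper stages and the classifier are commented out), so an instance of this ResNet acts as a shallow stride-4 feature extractor. The ResNet-50-style configuration below is an assumption for illustration; the self-supervised weights hinted at in the comments (e.g. DenseCL) would be loaded separately.

    import torch
    from networks.resnet import ResNet, Bottleneck

    net = ResNet(Bottleneck, [3, 4, 6, 3]).eval()  # ResNet-50 layout, chosen here as an assumption
    x = torch.rand(1, 3, 256, 512)
    with torch.no_grad():
        fea = net(x)
    print(fea.shape)  # torch.Size([1, 256, 64, 128]), i.e. layer1 output at 1/4 resolution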
/networks/U_net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 |
5 |
6 | class conv_block(nn.Module):
7 | def __init__(self, ch_in, ch_out):
8 | super(conv_block, self).__init__()
9 | self.conv = nn.Sequential(
10 | nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=False),
11 | nn.BatchNorm2d(ch_out),
12 | nn.ReLU(inplace=True),
13 | nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=False),
14 | nn.BatchNorm2d(ch_out),
15 | nn.ReLU(inplace=True)
16 | )
17 |
18 | def forward(self, x):
19 | x = self.conv(x)
20 | return x
21 |
22 |
23 | class up_conv(nn.Module):
24 | def __init__(self, ch_in, ch_out):
25 | super(up_conv, self).__init__()
26 | self.up = nn.Sequential(
27 | nn.Upsample(scale_factor=2),
28 | nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
29 | nn.BatchNorm2d(ch_out),
30 | nn.ReLU(inplace=True)
31 | )
32 |
33 | def forward(self, x):
34 | x = self.up(x)
35 | return x
36 |
37 |
38 | class U_Net(nn.Module):
39 | def __init__(self, img_ch=3, output_ch=1):
40 | super(U_Net, self).__init__()
41 |
42 | self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
43 |
44 | self.Conv1 = conv_block(ch_in=img_ch, ch_out=32)
45 | self.Conv2 = conv_block(ch_in=32, ch_out=64)
46 | self.Conv3 = conv_block(ch_in=64, ch_out=128)
47 | self.Conv4 = conv_block(ch_in=128, ch_out=256)
48 | # self.Conv5 = conv_block(ch_in=256, ch_out=512)
49 | self.Conv5 = conv_block(ch_in=256, ch_out=256)
50 |
51 | # self.Up5 = up_conv(ch_in=512, ch_out=256)
52 | self.Up5 = up_conv(ch_in=256, ch_out=256)
53 | self.Up_conv5 = conv_block(ch_in=512, ch_out=256)
54 |
55 | self.Up4 = up_conv(ch_in=256, ch_out=128)
56 | self.Up_conv4 = conv_block(ch_in=256, ch_out=128)
57 |
58 | self.Up3 = up_conv(ch_in=128, ch_out=64)
59 | self.Up_conv3 = conv_block(ch_in=128, ch_out=64)
60 |
61 | self.Up2 = up_conv(ch_in=64, ch_out=32)
62 | self.Up_conv2 = conv_block(ch_in=64, ch_out=32)
63 |
64 | self.Conv_1x1 = nn.Conv2d(32, output_ch, kernel_size=1, stride=1, padding=0)
65 |
66 | for m in self.modules():
67 | if isinstance(m, nn.Conv2d):
68 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
69 | m.weight.data.normal_(0, math.sqrt(2. / n))
70 | elif isinstance(m, nn.BatchNorm2d):
71 | m.weight.data.fill_(1)
72 | m.bias.data.zero_()
73 | elif isinstance(m, nn.Linear):
74 | m.bias.data.zero_()
75 |
76 | def forward(self, x):
77 | x1 = self.Conv1(x)
78 |
79 | x2 = self.Maxpool(x1)
80 | x2 = self.Conv2(x2)
81 |
82 | x3 = self.Maxpool(x2)
83 | x3 = self.Conv3(x3)
84 |
85 | x4 = self.Maxpool(x3)
86 | x4 = self.Conv4(x4)
87 |
88 | x5 = self.Maxpool(x4)
89 | x5 = self.Conv5(x5)
90 |
91 | d5 = self.Up5(x5)
92 | d5 = torch.cat((x4, d5), dim=1)
93 | d5 = self.Up_conv5(d5)
94 |
95 | d4 = self.Up4(d5)
96 | d4 = torch.cat((x3, d4), dim=1)
97 | d4 = self.Up_conv4(d4)
98 |
99 | d3 = self.Up3(d4)
100 | d3 = torch.cat((x2, d3), dim=1)
101 | d3 = self.Up_conv3(d3)
102 |
103 | d2 = self.Up2(d3)
104 | d2 = torch.cat((x1, d2), dim=1)
105 | d2 = self.Up_conv2(d2)
106 |
107 | d1 = self.Conv_1x1(d2)
108 |
109 | return d1
110 |
111 |
112 | class U_Net_v2(nn.Module):
113 | def __init__(self, img_ch=3, output_ch=1):
114 | super(U_Net_v2, self).__init__()
115 |
116 | self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
117 |
118 | self.Conv1 = conv_block(ch_in=img_ch, ch_out=32)
119 | self.Conv2 = conv_block(ch_in=32, ch_out=64)
120 | self.Conv3 = conv_block(ch_in=64, ch_out=128)
121 | self.Conv4 = conv_block(ch_in=128, ch_out=256)
122 |
123 | self.Up4 = up_conv(ch_in=256, ch_out=128)
124 | self.Up_conv4 = conv_block(ch_in=256, ch_out=128)
125 |
126 | self.Up3 = up_conv(ch_in=128, ch_out=64)
127 | self.Up_conv3 = conv_block(ch_in=128, ch_out=64)
128 |
129 | self.Up2 = up_conv(ch_in=64, ch_out=32)
130 | self.Up_conv2 = conv_block(ch_in=64, ch_out=32)
131 |
132 | self.Conv_1x1 = nn.Conv2d(32, output_ch, kernel_size=1, stride=1, padding=0)
133 |
134 | for m in self.modules():
135 | if isinstance(m, nn.Conv2d):
136 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
137 | m.weight.data.normal_(0, math.sqrt(2. / n))
138 | elif isinstance(m, nn.BatchNorm2d):
139 | m.weight.data.fill_(1)
140 | m.bias.data.zero_()
141 | elif isinstance(m, nn.Linear):
142 | m.bias.data.zero_()
143 |
144 | def forward(self, x):
145 | x1 = self.Conv1(x)
146 |
147 | x2 = self.Maxpool(x1)
148 | x2 = self.Conv2(x2)
149 |
150 | x3 = self.Maxpool(x2)
151 | x3 = self.Conv3(x3)
152 |
153 | x4 = self.Maxpool(x3)
154 | x4 = self.Conv4(x4)
155 |
156 | d4 = self.Up4(x4)
157 | d4 = torch.cat((x3, d4), dim=1)
158 | d4 = self.Up_conv4(d4)
159 |
160 | d3 = self.Up3(d4)
161 | d3 = torch.cat((x2, d3), dim=1)
162 | d3 = self.Up_conv3(d3)
163 |
164 | d2 = self.Up2(d3)
165 | d2 = torch.cat((x1, d2), dim=1)
166 | d2 = self.Up_conv2(d2)
167 |
168 | d1 = self.Conv_1x1(d2)
169 |
170 | return d1
171 |
172 |
173 | class U_Net_v3(nn.Module):
174 | def __init__(self, img_ch=3, output_ch=1):
175 | super(U_Net_v3, self).__init__()
176 |
177 | self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
178 |
179 | self.Conv0 = conv_block(ch_in=img_ch, ch_out=64)
180 | self.Conv1 = conv_block(ch_in=64, ch_out=128)
181 | self.Conv2 = conv_block(ch_in=128, ch_out=256)
182 |
183 | self.Up5 = up_conv(ch_in=256, ch_out=128)
184 | self.Up_conv5 = conv_block(ch_in=256, ch_out=128)
185 |
186 | self.Up4 = up_conv(ch_in=128, ch_out=64)
187 | self.Up_conv4 = conv_block(ch_in=128, ch_out=64)
188 |
189 | self.Up3 = up_conv(ch_in=64, ch_out=32)
190 | self.Up_conv3 = conv_block(ch_in=32, ch_out=32)
191 |
192 | self.Up2 = up_conv(ch_in=32, ch_out=32)
193 | self.Up_conv2 = conv_block(ch_in=32, ch_out=32)
194 |
195 | self.Conv_1x1 = nn.Conv2d(32, output_ch, kernel_size=1, stride=1, padding=0)
196 |
197 | for m in self.modules():
198 | if isinstance(m, nn.Conv2d):
199 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
200 | m.weight.data.normal_(0, math.sqrt(2. / n))
201 | elif isinstance(m, nn.BatchNorm2d):
202 | m.weight.data.fill_(1)
203 | m.bias.data.zero_()
204 | elif isinstance(m, nn.Linear):
205 | m.bias.data.zero_()
206 |
207 | def forward(self, x):
208 | x0 = self.Conv0(x) # 64 channels
209 |
210 | x1 = self.Conv1(x0) # 128 channels
211 | x1 = self.Maxpool(x1) # 1/8 resolution
212 |
213 | x2 = self.Conv2(x1) # 256 channels
214 | x2 = self.Maxpool(x2) # 1/16 resolution
215 |
216 | d4 = self.Up5(x2) # 1/8 resolution
217 | d4 = torch.cat((x1, d4), dim=1)
218 | d4 = self.Up_conv5(d4) # 128 channels
219 |
220 | d3 = self.Up4(d4) # 1/4 resolution
221 | d3 = torch.cat((x0, d3), dim=1)
222 | d3 = self.Up_conv4(d3) # 64 channels
223 |
224 | d2 = self.Up3(d3) # 1/2 resolution
225 | d2 = self.Up_conv3(d2) # 32 channels
226 |
227 | d1 = self.Up2(d2) # 1/2 resolution
228 | d1 = self.Up_conv2(d1) # 32 channels
229 |
230 | d0 = self.Conv_1x1(d1)
231 |
232 | return d0
233 |
234 |
235 | class U_Net_v4(nn.Module):
236 | def __init__(self, img_ch, output_ch):
237 | super(U_Net_v4, self).__init__()
238 |
239 | self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
240 |
241 | self.Conv1 = conv_block(ch_in=img_ch, ch_out=32)
242 | self.Conv2 = conv_block(ch_in=32, ch_out=64)
243 | self.Conv3 = conv_block(ch_in=64, ch_out=128)
244 |
245 | self.Conv4 = conv_block(ch_in=128, ch_out=128)
246 |
247 | self.Up4 = conv_block(ch_in=128, ch_out=128)
248 | self.Up_conv4 = up_conv(ch_in=256, ch_out=64)
249 |
250 | self.Up3 = conv_block(ch_in=64, ch_out=64)
251 | self.Up_conv3 = up_conv(ch_in=128, ch_out=32)
252 |
253 | self.last_conv = nn.Conv2d(64, output_ch, 1, 1, 0, 1)
254 |
255 | for m in self.modules():
256 | if isinstance(m, nn.Conv2d):
257 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
258 | m.weight.data.normal_(0, math.sqrt(2. / n))
259 | # nn.init.kaiming_normal_(m, mode='fan_in', nonlinearity='relu')
260 | elif isinstance(m, nn.BatchNorm2d):
261 | m.weight.data.fill_(1)
262 | m.bias.data.zero_()
263 | elif isinstance(m, nn.Linear):
264 | m.bias.data.zero_()
265 |
266 | def forward(self, x):
267 | x1 = self.Conv1(x) # 32, 1/4
268 |
269 | x2 = self.Maxpool(x1)
270 | x2 = self.Conv2(x2) # 64, 1/8
271 |
272 | x3 = self.Maxpool(x2)
273 | x3 = self.Conv3(x3) # 128, 1/16
274 |
275 | x4 = self.Conv4(x3) # 128, 1/16
276 |
277 | d4 = self.Up4(x4) # 128, 1/16
278 | d4 = torch.cat((x3, d4), dim=1)
279 | d4 = self.Up_conv4(d4) # 64, 1/8
280 |
281 | d3 = self.Up3(d4) # 64, 1/8
282 | d3 = torch.cat((x2, d3), dim=1)
283 | d3 = self.Up_conv3(d3) # 32, 1/4
284 |
285 | d2 = torch.cat((x1, d3), dim=1)
286 | d2 = self.last_conv(d2)
287 |
288 | return d2
289 |
290 |
291 | class LinearProj(nn.Module):
292 | def __init__(self, in_c, out_c):
293 | super(LinearProj, self).__init__()
294 | self.conv = nn.Sequential(
295 | nn.Conv2d(in_c, out_c, 1, 1, 0, 1),
296 | nn.ReLU(inplace=True),
297 | nn.Conv2d(out_c, out_c, 1, 1, 0, 1))
298 | # self.conv = nn.Conv2d(in_c, out_c, 1, 1, 0, 1)
299 |
300 | def forward(self, x):
301 | x = self.conv(x)
302 | return x
303 |
304 |
305 | if __name__ == '__main__':
306 | a = torch.rand(2, 3, 64, 128).cuda()
307 | net = U_Net_v3(img_ch=3, output_ch=4).cuda()
308 | b = net(a)
309 | print(b.shape)
--------------------------------------------------------------------------------
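U_Net_v4 is the adaptor instantiated by train_adaptor.py and retrain_CostAggregation.py: it maps 256-channel backbone features to 64-channel features at the same spatial resolution (height and width should be divisible by 4 so the skip connections line up). A quick shape check:

    import torch
    from networks.U_net import U_Net_v4

    adaptor = U_Net_v4(img_ch=256, output_ch=64).eval()
    fea = torch.rand(1, 256, 64, 128)   # e.g. quarter-resolution backbone features
    with torch.no_grad():
        out = adaptor(fea)
    print(out.shape)                    # torch.Size([1, 64, 64, 128])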
/networks/Aggregator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.utils.data
4 | from torch.autograd import Variable
5 | import torch.nn.functional as F
6 | import math
7 | from networks.submodule import convbn, convbn_3d, DisparityRegression
8 | from networks.stackhourglass import hourglass_gwcnet, hourglass
9 | import matplotlib.pyplot as plt
10 | import loss_functions as lf
11 |
12 |
13 | def build_cost_volume(left_fea, right_fea, max_disp, cost_type):
14 | if cost_type == 'cor':
15 |
16 | left_fea_norm = F.normalize(left_fea, dim=1)
17 | right_fea_norm = F.normalize(right_fea, dim=1)
18 |
19 | cost = torch.zeros(left_fea.size()[0], 1, max_disp // 4,
20 | left_fea.size()[2], left_fea.size()[3]).cuda()
21 |
22 | for i in range(max_disp // 4):
23 | if i > 0:
24 | cost[:, :, i, :, i:] = (torch.sum(left_fea_norm[:, :, :, i:] * right_fea_norm[:, :, :, :-i],
25 | dim=1, keepdim=True) + 1) / 2
26 | else:
27 | cost[:, :, i, :, :] = (torch.sum(left_fea_norm * right_fea_norm, dim=1, keepdim=True) + 1) / 2
28 |
29 | elif cost_type == 'l2':
30 | cost = torch.zeros(left_fea.size()[0], 1, max_disp // 4,
31 | left_fea.size()[2], left_fea.size()[3]).cuda()
32 |
33 | for i in range(max_disp // 4):
34 | if i > 0:
35 | cost[:, :, i, :, i:] = torch.sqrt(torch.sum(
36 | (left_fea[:, :, :, i:] - right_fea[:, :, :, :-i]) ** 2, dim=1, keepdim=True))
37 |
38 | else:
39 | cost[:, :, i, :, :] = torch.sqrt(torch.sum((left_fea - right_fea) ** 2, dim=1, keepdim=True))
40 |
41 | elif cost_type == 'cat':
42 |
43 | cost = torch.zeros(left_fea.size()[0], left_fea.size()[1] * 2, max_disp // 4,
44 | left_fea.size()[2], left_fea.size()[3]).cuda()
45 |
46 | for i in range(max_disp // 4):
47 | if i > 0:
48 | cost[:, :left_fea.size()[1], i, :, i:] = left_fea[:, :, :, i:]
49 | cost[:, left_fea.size()[1]:, i, :, i:] = right_fea[:, :, :, :-i]
50 | else:
51 | cost[:, :left_fea.size()[1], i, :, :] = left_fea
52 | cost[:, left_fea.size()[1]:, i, :, :] = right_fea
53 |
54 | elif cost_type == 'ncat':
55 |
56 | left_fea = F.normalize(left_fea, dim=1)
57 | right_fea = F.normalize(right_fea, dim=1)
58 |
59 | cost = torch.zeros(left_fea.size()[0], left_fea.size()[1] * 2, max_disp // 4,
60 | left_fea.size()[2], left_fea.size()[3]).cuda()
61 |
62 | for i in range(max_disp // 4):
63 | if i > 0:
64 | cost[:, :left_fea.size()[1], i, :, i:] = left_fea[:, :, :, i:]
65 | cost[:, left_fea.size()[1]:, i, :, i:] = right_fea[:, :, :, :-i]
66 | else:
67 | cost[:, :left_fea.size()[1], i, :, :] = left_fea
68 | cost[:, left_fea.size()[1]:, i, :, :] = right_fea
69 |
70 | cost = cost.contiguous()
71 |
72 | return cost
73 |
74 |
75 | class GwcAggregator(nn.Module):
76 | def __init__(self, maxdisp):
77 | super(GwcAggregator, self).__init__()
78 | self.maxdisp = maxdisp
79 |
80 | self.dres0 = nn.Sequential(convbn_3d(64, 32, 3, 1, 1),
81 | nn.ReLU(inplace=True),
82 | convbn_3d(32, 32, 3, 1, 1),
83 | nn.ReLU(inplace=True))
84 | self.dres1 = nn.Sequential(convbn_3d(32, 32, 3, 1, 1),
85 | nn.ReLU(inplace=True),
86 | convbn_3d(32, 32, 3, 1, 1),
87 | nn.ReLU(inplace=True))
88 |
89 | self.hg1 = hourglass_gwcnet(32)
90 | self.hg2 = hourglass_gwcnet(32)
91 | self.hg3 = hourglass_gwcnet(32)
92 |
93 | self.classify1 = nn.Sequential(convbn_3d(32, 32, 3, 1, 1),
94 | nn.ReLU(inplace=True),
95 | nn.Conv3d(32, 1, kernel_size=3, padding=1, stride=1, bias=False))
96 | self.classify2 = nn.Sequential(convbn_3d(32, 32, 3, 1, 1),
97 | nn.ReLU(inplace=True),
98 | nn.Conv3d(32, 1, kernel_size=3, padding=1, stride=1, bias=False))
99 | self.classify3 = nn.Sequential(convbn_3d(32, 32, 3, 1, 1),
100 | nn.ReLU(inplace=True),
101 | nn.Conv3d(32, 1, kernel_size=3, padding=1, stride=1, bias=False))
102 |
103 | for m in self.modules():
104 | if isinstance(m, nn.Conv2d):
105 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
106 | m.weight.data.normal_(0, math.sqrt(2. / n))
107 | elif isinstance(m, nn.Conv3d):
108 | n = m.kernel_size[0] * m.kernel_size[1]*m.kernel_size[2] * m.out_channels
109 | m.weight.data.normal_(0, math.sqrt(2. / n))
110 | elif isinstance(m, nn.BatchNorm2d):
111 | m.weight.data.fill_(1)
112 | m.bias.data.zero_()
113 | elif isinstance(m, nn.BatchNorm3d):
114 | m.weight.data.fill_(1)
115 | m.bias.data.zero_()
116 | elif isinstance(m, nn.Linear):
117 | m.bias.data.zero_()
118 |
119 | def forward(self, left_fea, right_fea, gt_left, gt_right):
120 | cost = build_cost_volume(left_fea, right_fea, self.maxdisp, cost_type='ncat')
121 |
122 | cost0 = self.dres0(cost)
123 | cost1 = self.dres1(cost0) + cost0
124 |
125 | out1 = self.hg1(cost1)
126 | out2 = self.hg2(out1)
127 | out3 = self.hg3(out2)
128 |
129 | win_s = 5
130 |
131 | if self.training:
132 | cost1 = self.classify1(out1)
133 | cost1 = F.interpolate(cost1, scale_factor=4, mode='trilinear', align_corners=True)
134 | cost1 = torch.squeeze(cost1, 1)
135 | distribute1 = F.softmax(cost1, dim=1)
136 | pred1 = DisparityRegression(self.maxdisp, win_size=win_s)(distribute1)
137 |
138 | cost2 = self.classify2(out2)
139 | cost2 = F.interpolate(cost2, scale_factor=4, mode='trilinear', align_corners=True)
140 | cost2 = torch.squeeze(cost2, 1)
141 | distribute2 = F.softmax(cost2, dim=1)
142 | pred2 = DisparityRegression(self.maxdisp, win_size=win_s)(distribute2)
143 |
144 | cost3 = self.classify3(out3)
145 | cost3 = F.interpolate(cost3, scale_factor=4, mode='trilinear', align_corners=True)
146 | cost3 = torch.squeeze(cost3, 1)
147 | distribute3 = F.softmax(cost3, dim=1)
148 | pred3 = DisparityRegression(self.maxdisp, win_size=win_s)(distribute3)
149 |
150 | if self.training:
151 | mask = (gt_left < self.maxdisp) & (gt_left > 0)
152 | loss1 = 0.5 * F.smooth_l1_loss(pred1[mask], gt_left[mask]) + \
153 | 0.7 * F.smooth_l1_loss(pred2[mask], gt_left[mask]) + \
154 | F.smooth_l1_loss(pred3[mask], gt_left[mask])
155 |
156 | gt_distribute = lf.disp2distribute(gt_left, self.maxdisp, b=2)
157 | loss2 = 0.5 * lf.CEloss(gt_left, self.maxdisp, gt_distribute, distribute1) + \
158 | 0.7 * lf.CEloss(gt_left, self.maxdisp, gt_distribute, distribute2) + \
159 | lf.CEloss(gt_left, self.maxdisp, gt_distribute, distribute3)
160 |
161 | loss3 = lf.FeatureSimilarityLoss(self.maxdisp)(left_fea, right_fea, gt_left, gt_right)
162 |
163 | return loss1, loss2, loss3
164 |
165 | else:
166 | return pred3
167 |
168 |
169 | class PSMAggregator(nn.Module):
170 | def __init__(self, maxdisp, udc):
171 | super(PSMAggregator, self).__init__()
172 | self.maxdisp = maxdisp
173 | self.udc = udc
174 |
175 | self.dres0 = nn.Sequential(convbn_3d(1, 32, 3, 1, 1),
176 | nn.ReLU(inplace=True),
177 | convbn_3d(32, 32, 3, 1, 1),
178 | nn.ReLU(inplace=True))
179 | self.dres1 = nn.Sequential(convbn_3d(32, 32, 3, 1, 1),
180 | nn.ReLU(inplace=True),
181 | convbn_3d(32, 32, 3, 1, 1),
182 | nn.ReLU(inplace=True))
183 |
184 | self.hg1 = hourglass(32)
185 | self.hg2 = hourglass(32)
186 | self.hg3 = hourglass(32)
187 |
188 | self.classify1 = nn.Sequential(convbn_3d(32, 32, 3, 1, 1),
189 | nn.ReLU(inplace=True),
190 | nn.Conv3d(32, 1, kernel_size=3, padding=1, stride=1, bias=False))
191 | self.classify2 = nn.Sequential(convbn_3d(32, 32, 3, 1, 1),
192 | nn.ReLU(inplace=True),
193 | nn.Conv3d(32, 1, kernel_size=3, padding=1, stride=1, bias=False))
194 | self.classify3 = nn.Sequential(convbn_3d(32, 32, 3, 1, 1),
195 | nn.ReLU(inplace=True),
196 | nn.Conv3d(32, 1, kernel_size=3, padding=1, stride=1, bias=False))
197 |
198 | for m in self.modules():
199 | if isinstance(m, nn.Conv2d):
200 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
201 | m.weight.data.normal_(0, math.sqrt(2. / n))
202 | elif isinstance(m, nn.Conv3d):
203 | n = m.kernel_size[0] * m.kernel_size[1]*m.kernel_size[2] * m.out_channels
204 | m.weight.data.normal_(0, math.sqrt(2. / n))
205 | elif isinstance(m, nn.BatchNorm2d):
206 | m.weight.data.fill_(1)
207 | m.bias.data.zero_()
208 | elif isinstance(m, nn.BatchNorm3d):
209 | m.weight.data.fill_(1)
210 | m.bias.data.zero_()
211 | elif isinstance(m, nn.Linear):
212 | m.bias.data.zero_()
213 |
214 | def forward(self, left_fea, right_fea, gt_left, training):
215 | cost = build_cost_volume(left_fea, right_fea, self.maxdisp, cost_type='cor')
216 |
217 | cost0 = self.dres0(cost)
218 | cost1 = self.dres1(cost0) + cost0
219 |
220 | out1, pre1, post1 = self.hg1(cost1, None, None)
221 | out1 = out1+cost0
222 |
223 | out2, pre2, post2 = self.hg2(out1, pre1, post1)
224 | out2 = out2+cost0
225 |
226 | out3, pre3, post3 = self.hg3(out2, pre1, post2)
227 | out3 = out3+cost0
228 |
229 | cost1 = self.classify1(out1)
230 | cost2 = self.classify2(out2) + cost1
231 | cost3 = self.classify3(out3) + cost2
232 |
233 | if self.udc:
234 | win_s = 5
235 | else:
236 | win_s = 0
237 |
238 | if self.training:
239 | cost1 = F.interpolate(cost1, scale_factor=4, mode='trilinear', align_corners=True)
240 | cost1 = torch.squeeze(cost1, 1)
241 | distribute1 = F.softmax(cost1, dim=1)
242 | pred1 = DisparityRegression(self.maxdisp, win_size=win_s)(distribute1)
243 |
244 | cost2 = F.interpolate(cost2, scale_factor=4, mode='trilinear', align_corners=True)
245 | cost2 = torch.squeeze(cost2, 1)
246 | distribute2 = F.softmax(cost2, dim=1)
247 | pred2 = DisparityRegression(self.maxdisp, win_size=win_s)(distribute2)
248 |
249 | cost3 = F.interpolate(cost3, scale_factor=4, mode='trilinear', align_corners=True)
250 | cost3 = torch.squeeze(cost3, 1)
251 | distribute3 = F.softmax(cost3, dim=1)
252 | pred3 = DisparityRegression(self.maxdisp, win_size=win_s)(distribute3)
253 |
254 | if self.training:
255 | mask = (gt_left < self.maxdisp) & (gt_left > 0)
256 |
257 | loss1 = 0.5 * F.smooth_l1_loss(pred1[mask], gt_left[mask]) + \
258 | 0.7 * F.smooth_l1_loss(pred2[mask], gt_left[mask]) + \
259 | F.smooth_l1_loss(pred3[mask], gt_left[mask])
260 |
261 | gt_distribute = lf.disp2distribute(gt_left, self.maxdisp, b=2)
262 | loss2 = 0.5 * lf.CEloss(gt_left, self.maxdisp, gt_distribute, distribute1) + \
263 | 0.7 * lf.CEloss(gt_left, self.maxdisp, gt_distribute, distribute2) + \
264 | lf.CEloss(gt_left, self.maxdisp, gt_distribute, distribute3)
265 | return loss1, loss2
266 |
267 | else:
268 | if training:
269 | mask = (gt_left < self.maxdisp) & (gt_left > 0)
270 | loss1 = F.smooth_l1_loss(pred3[mask], gt_left[mask])
271 | # loss2 = loss1
272 | gt_distribute = lf.disp2distribute(gt_left, self.maxdisp, b=2)
273 | loss2 = lf.CEloss(gt_left, self.maxdisp, gt_distribute, distribute3)
274 | return loss1, loss2
275 |
276 | else:
277 | return pred3
278 |
--------------------------------------------------------------------------------
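For cost_type='cor', build_cost_volume returns a single-channel volume of cosine similarities rescaled to [0, 1], with shape (B, 1, max_disp // 4, H, W); this is the volume PSMAggregator filters with its 3D hourglasses. A small sketch, assuming a CUDA device since the volume is allocated with .cuda():

    import torch
    from networks.Aggregator import build_cost_volume

    left_fea = torch.rand(1, 64, 64, 128).cuda()   # e.g. adaptor output at 1/4 resolution
    right_fea = torch.rand(1, 64, 64, 128).cuda()
    cost = build_cost_volume(left_fea, right_fea, max_disp=192, cost_type='cor')
    print(cost.shape)  # torch.Size([1, 1, 48, 64, 128])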