├── .gitignore
├── Data
│   ├── OurDataSet.py
│   ├── SceneFlow.py
│   ├── SceneFlow_helper.py
│   ├── SceneFlow_helper1.py
│   ├── __init__.py
│   ├── pfm_helper.py
│   └── val.npy
├── LICENSE
├── Losses
│   ├── __init__.py
│   └── supervise.py
├── Metrics
│   ├── __init__.py
│   └── metrics.py
├── Models
│   ├── ActiveStereoNet.py
│   ├── StereoNet.py
│   ├── __init__.py
│   └── blocks.py
├── Options
│   ├── __init__.py
│   └── example.json
├── Sovlers
│   ├── __init__.py
│   ├── solver_test.py
│   └── solver_train.py
├── main.py
└── readme.md

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.vscode
__pycache__

code_test.py
test.py

*.png
Experiments

--------------------------------------------------------------------------------
/Data/OurDataSet.py:
--------------------------------------------------------------------------------
import os
import pdb

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]

--------------------------------------------------------------------------------
/Data/SceneFlow.py:
--------------------------------------------------------------------------------
import os
import torch
import numpy as np

from torch.utils.data import Dataset
from PIL import Image

from .SceneFlow_helper import read_sceneflow
from .pfm_helper import read_pfm

import pdb

class SceneFlowDataset(Dataset):

    def __init__(self, data_root, npy_root, val_split, test_split, transform, phase):

        super(SceneFlowDataset, self).__init__()

        self.data_root = data_root
        self.npy_root = npy_root
        self.phase = phase
        self.val_split = val_split
        self.test_split = test_split
        self.transform = transform

        self.left_imgs, self.right_imgs, self.disps, self.test_left_imgs, self.test_right_imgs, self.test_disps, self.disps_R, self.test_disps_R = read_sceneflow(self.data_root)
        #pdb.set_trace()

        assert len(self.left_imgs) == len(self.right_imgs) == len(self.disps) == len(self.disps_R), 'Invalid training dataset!'
        assert len(self.test_left_imgs) == len(self.test_right_imgs) == len(self.test_disps) == len(self.test_disps_R), 'Invalid testing dataset!'
        #total_data_num = len(self.left_imgs)

        #self.nb_train = int((1 - self.val_split - self.test_split) * total_data_num)
        #self.nb_val = int(self.val_split * total_data_num)
        #self.nb_test = int(self.test_split * total_data_num)

        test_data_num = len(self.test_left_imgs)

        self.nb_train = len(self.left_imgs)
        self.nb_val = int(self.val_split * test_data_num)
        self.nb_test = test_data_num

        train_npy = os.path.join(self.npy_root, 'train.npy')
        val_npy = os.path.join(self.npy_root, 'val.npy')
        test_npy = os.path.join(self.npy_root, 'test.npy')

        if os.path.exists(train_npy) and os.path.exists(val_npy) and os.path.exists(test_npy):
            #self.train_list = np.load(train_npy)
            self.val_list = np.load(val_npy)
            #pdb.set_trace()
            #self.test_list = np.load(test_npy)

        else:
            #total_idcs = np.random.permutation(total_data_num)
            #self.train_list = total_idcs[0:self.nb_train]
            #self.val_list = total_idcs[self.nb_train:self.nb_train + self.nb_val]
            #self.test_list = total_idcs[self.nb_train + self.nb_val:]

            test_idcs = np.random.permutation(test_data_num)
            self.val_list = test_idcs[0:self.nb_val]

            #np.save(train_npy, self.train_list)
            np.save(val_npy, self.val_list)
            #np.save(test_npy, self.test_list)

    def __len__(self):

        if self.phase == 'train':
            return self.nb_train
        elif self.phase == 'val':
            return self.nb_val
        elif self.phase == 'test':
            return self.nb_test

    def __getitem__(self, index):

        if self.phase == 'train':
            left_image = self._read_image(self.left_imgs[index])
            right_image = self._read_image(self.right_imgs[index])
            left_disp, scale = read_pfm(self.disps[index])
            right_disp, scale = read_pfm(self.disps_R[index])

        elif self.phase == 'val':
            index = self.val_list[index]
            left_image = self._read_image(self.test_left_imgs[index])
            right_image = self._read_image(self.test_right_imgs[index])
            left_disp, scale = read_pfm(self.test_disps[index])
            right_disp, scale = read_pfm(self.test_disps_R[index])

        elif self.phase == 'test':
            left_image = self._read_image(self.test_left_imgs[index])
            right_image = self._read_image(self.test_right_imgs[index])
            left_disp, scale = read_pfm(self.test_disps[index])
            right_disp, scale = read_pfm(self.test_disps_R[index])

        if self.transform:
            left_image = self.transform(left_image)
            right_image = self.transform(right_image)

        left_disp = torch.Tensor(left_disp)
        right_disp = torch.Tensor(right_disp)
        return left_image, right_image, left_disp, right_disp

    '''
    def __getitem__(self, index):

        if self.phase == 'train':
            index = self.train_list[index]
        elif self.phase == 'val':
            index = self.val_list[index]
        elif self.phase == 'test':
            index = self.test_list[index]

        left_image = self._read_image(self.left_imgs[index])
        right_image = self._read_image(self.right_imgs[index])
        left_disp, scale = read_pfm(self.disps[index])

        if self.transform:
            left_image = self.transform(left_image)
            right_image = self.transform(right_image)

        left_disp = torch.Tensor(left_disp)

        return left_image, right_image, left_disp, scale
    '''

    def _read_image(self, filename):

        attempt = True
        while attempt:
            try:
                with open(filename, 'rb') as f:
                    img = Image.open(f).convert('RGB')
                attempt = False
            except IOError as e:
                print('[IOError] {}, keep trying...'.format(e))
                attempt = True
        return img

--------------------------------------------------------------------------------
/Data/SceneFlow_helper.py:
--------------------------------------------------------------------------------
import os
import pdb

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]

def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)

def read_sceneflow(filepath):

    classes = [d for d in os.listdir(filepath) if os.path.isdir(os.path.join(filepath, d))]

    image = [img for img in classes if img.find('frames_cleanpass') > -1]
    disp = [dsp for dsp in classes if dsp.find('disparity') > -1]

    monkaa_path = filepath + [x for x in image if 'monkaa' in x][0]
    monkaa_disp = filepath + [x for x in disp if 'monkaa' in x][0]
    monkaa_dir = os.listdir(monkaa_path)

    all_left_img = []
    all_right_img = []
    all_left_disp = []
    all_right_disp = []
    test_left_img = []
    test_right_img = []
    test_left_disp = []
    test_right_disp = []

    for dd in monkaa_dir:
        for im in os.listdir(monkaa_path + '/' + dd + '/left/'):
            if is_image_file(monkaa_path + '/' + dd + '/left/' + im):
                all_left_img.append(monkaa_path + '/' + dd + '/left/' + im)
                all_left_disp.append(monkaa_disp + '/' + dd + '/left/' + im.split(".")[0] + '.pfm')

        for im in os.listdir(monkaa_path + '/' + dd + '/right/'):
            if is_image_file(monkaa_path + '/' + dd + '/right/' + im):
                all_right_img.append(monkaa_path + '/' + dd + '/right/' + im)
                all_right_disp.append(monkaa_disp + '/' + dd + '/right/' + im.split(".")[0] + '.pfm')

    flying_path = filepath + [x for x in image if x == 'frames_cleanpass'][0]
    flying_disp = filepath + [x for x in disp if x == 'frames_disparity'][0]
    flying_dir = flying_path + '/TRAIN/'
    subdir = ['A', 'B', 'C']

    for ss in subdir:
        flying = os.listdir(flying_dir + ss)

        for ff in flying:

            imm_l = os.listdir(flying_dir + ss + '/' + ff + '/left/')

            for im in imm_l:
                if is_image_file(flying_dir + ss + '/' + ff + '/left/' + im):
                    all_left_img.append(flying_dir + ss + '/' + ff + '/left/' + im)
                    all_left_disp.append(flying_disp + '/TRAIN/' + ss + '/' + ff + '/left/' + im.split(".")[0] + '.pfm')

                if is_image_file(flying_dir + ss + '/' + ff + '/right/' + im):
                    all_right_img.append(flying_dir + ss + '/' + ff + '/right/' + im)
                    all_right_disp.append(flying_disp + '/TRAIN/' + ss + '/' + ff + '/right/' + im.split(".")[0] + '.pfm')

    flying_dir = flying_path + '/TEST/'

    subdir = ['A', 'B', 'C']

    for ss in subdir:
        flying = os.listdir(flying_dir + ss)

        for ff in flying:

            imm_l = os.listdir(flying_dir + ss + '/' + ff + '/left/')

            for im in imm_l:
                if is_image_file(flying_dir + ss + '/' + ff + '/left/' + im):
                    test_left_img.append(flying_dir + ss + '/' + ff + '/left/' + im)
                    test_left_disp.append(flying_disp + '/TEST/' + ss + '/' + ff + '/left/' + im.split(".")[0] + '.pfm')

                if is_image_file(flying_dir + ss + '/' + ff + '/right/' + im):
                    test_right_img.append(flying_dir + ss + '/' + ff + '/right/' + im)
                    test_right_disp.append(flying_disp + '/TEST/' + ss + '/' + ff + '/right/' + im.split(".")[0] + '.pfm')

    driving_dir = filepath + [x for x in image if 'driving' in x][0] + '/'
    driving_disp = filepath + [x for x in disp if 'driving' in x][0]

    subdir1 = ['35mm_focallength', '15mm_focallength']
    #subdir1 = ['15mm_focallength']
    subdir2 = ['scene_backwards', 'scene_forwards']
    #subdir2 = ['scene_backwards']
    subdir3 = ['fast', 'slow']
    #subdir3 = ['fast']

    for i in subdir1:
        for j in subdir2:
            for k in subdir3:
                imm_l = os.listdir(driving_dir + i + '/' + j + '/' + k + '/left/')
                for im in imm_l:
                    if is_image_file(driving_dir + i + '/' + j + '/' + k + '/left/' + im):
                        all_left_img.append(driving_dir + i + '/' + j + '/' + k + '/left/' + im)
                        all_left_disp.append(driving_disp + '/' + i + '/' + j + '/' + k + '/left/' + im.split(".")[0] + '.pfm')

                    if is_image_file(driving_dir + i + '/' + j + '/' + k + '/right/' + im):
                        all_right_img.append(driving_dir + i + '/' + j + '/' + k + '/right/' + im)
                        all_right_disp.append(driving_disp + '/' + i + '/' + j + '/' + k + '/right/' + im.split(".")[0] + '.pfm')

    #pdb.set_trace()
    return all_left_img, all_right_img, all_left_disp, test_left_img, test_right_img, test_left_disp, all_right_disp, test_right_disp

--------------------------------------------------------------------------------
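Note: read_sceneflow returns its eight lists in a non-obvious order; the two right-view disparity lists come last rather than next to their image lists. Callers must unpack in exactly this order, as Data/SceneFlow.py does (data_root as configured in the option file):

```python
(left_imgs, right_imgs, disps,
 test_left_imgs, test_right_imgs, test_disps,
 disps_R, test_disps_R) = read_sceneflow(data_root)
```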
/Data/SceneFlow_helper1.py:
--------------------------------------------------------------------------------
import os
import pdb

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]

def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)

def read_sceneflow(filepath):

    classes = [d for d in os.listdir(filepath) if os.path.isdir(os.path.join(filepath, d))]

    image = [img for img in classes if img.find('frames_cleanpass') > -1]
    disp = [dsp for dsp in classes if dsp.find('disparity') > -1]

    monkaa_path = filepath + [x for x in image if 'monkaa' in x][0]
    monkaa_disp = filepath + [x for x in disp if 'monkaa' in x][0]
    monkaa_dir = os.listdir(monkaa_path)

    all_left_img = []
    all_right_img = []
    all_left_disp = []
    test_left_img = []
    test_right_img = []
    test_left_disp = []

    for dd in monkaa_dir:
        for im in os.listdir(monkaa_path + '/' + dd + '/left/'):
            if is_image_file(monkaa_path + '/' + dd + '/left/' + im):
                all_left_img.append(monkaa_path + '/' + dd + '/left/' + im)
                all_left_disp.append(monkaa_disp + '/' + dd + '/left/' + im.split(".")[0] + '.pfm')

        for im in os.listdir(monkaa_path + '/' + dd + '/right/'):
            if is_image_file(monkaa_path + '/' + dd + '/right/' + im):
                all_right_img.append(monkaa_path + '/' + dd + '/right/' + im)

    flying_path = filepath + [x for x in image if x == 'frames_cleanpass'][0]
    flying_disp = filepath + [x for x in disp if x == 'frames_disparity'][0]
    flying_dir = flying_path + '/TRAIN/'
    subdir = ['A', 'B', 'C']

    for ss in subdir:
        flying = os.listdir(flying_dir + ss)

        for ff in flying:

            imm_l = os.listdir(flying_dir + ss + '/' + ff + '/left/')

            for im in imm_l:
                if is_image_file(flying_dir + ss + '/' + ff + '/left/' + im):
                    all_left_img.append(flying_dir + ss + '/' + ff + '/left/' + im)
                    all_left_disp.append(flying_disp + '/TRAIN/' + ss + '/' + ff + '/left/' + im.split(".")[0] + '.pfm')

                if is_image_file(flying_dir + ss + '/' + ff + '/right/' + im):
                    all_right_img.append(flying_dir + ss + '/' + ff + '/right/' + im)

    flying_dir = flying_path + '/TEST/'

    subdir = ['A', 'B', 'C']

    for ss in subdir:
        flying = os.listdir(flying_dir + ss)

        for ff in flying:

            imm_l = os.listdir(flying_dir + ss + '/' + ff + '/left/')

            for im in imm_l:
                if is_image_file(flying_dir + ss + '/' + ff + '/left/' + im):
                    test_left_img.append(flying_dir + ss + '/' + ff + '/left/' + im)
                    test_left_disp.append(flying_disp + '/TEST/' + ss + '/' + ff + '/left/' + im.split(".")[0] + '.pfm')

                if is_image_file(flying_dir + ss + '/' + ff + '/right/' + im):
                    test_right_img.append(flying_dir + ss + '/' + ff + '/right/' + im)

    driving_dir = filepath + [x for x in image if 'driving' in x][0] + '/'
    driving_disp = filepath + [x for x in disp if 'driving' in x][0]

    subdir1 = ['35mm_focallength', '15mm_focallength']
    #subdir1 = ['15mm_focallength']
    subdir2 = ['scene_backwards', 'scene_forwards']
    #subdir2 = ['scene_backwards']
    subdir3 = ['fast', 'slow']
    #subdir3 = ['fast']

    for i in subdir1:
        for j in subdir2:
            for k in subdir3:
                imm_l = os.listdir(driving_dir + i + '/' + j + '/' + k + '/left/')
                for im in imm_l:
                    if is_image_file(driving_dir + i + '/' + j + '/' + k + '/left/' + im):
                        all_left_img.append(driving_dir + i + '/' + j + '/' + k + '/left/' + im)
                        all_left_disp.append(driving_disp + '/' + i + '/' + j + '/' + k + '/left/' + im.split(".")[0] + '.pfm')

                    if is_image_file(driving_dir + i + '/' + j + '/' + k + '/right/' + im):
                        all_right_img.append(driving_dir + i + '/' + j + '/' + k + '/right/' + im)

    return all_left_img, all_right_img, all_left_disp, test_left_img, test_right_img, test_left_disp

--------------------------------------------------------------------------------
/Data/__init__.py:
--------------------------------------------------------------------------------
import torch
import torchvision.transforms as T

from torch.utils.data import DataLoader
from .SceneFlow import SceneFlowDataset

def get_loader(config):

    dset = config['dataset_name'].lower()
    if dset == 'sceneflow':
        return get_scene_flow_loader(config)
    else:
        raise NotImplementedError('Dataset [{:s}] is not supported.'.format(dset))

def get_scene_flow_loader(config):

    cfg_mode = config['mode'].lower()

    if cfg_mode == 'train':
        train_loader = DataLoader(
            create_scene_flow_dataset(config['data'], 'train'),
            batch_size=config['solver']['batch_size'],
            shuffle=True,
            pin_memory=True,
            drop_last=True
        )
        val_loader = DataLoader(
            create_scene_flow_dataset(config['data'], 'val'),
            batch_size=config['solver']['batch_size'],
            shuffle=False,
            pin_memory=True,
            drop_last=False
        )
        return train_loader, val_loader
    elif cfg_mode == 'test':
        test_loader = DataLoader(
            create_scene_flow_dataset(config['data'], 'test'),
            batch_size=config['solver']['batch_size'],
            shuffle=False,
            pin_memory=True,
            drop_last=False
        )
        return test_loader
    else:
        raise NotImplementedError('Mode [{:s}] is not supported.'.format(cfg_mode))
def create_scene_flow_dataset(cfg_data, mode):

    data_root = cfg_data['data_root']
    npy_root = cfg_data['npy_root']
    test_split = cfg_data['test_split']
    val_split = cfg_data['val_split']

    transform = T.Compose([
        T.ToTensor(),
        T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
    ])

    return SceneFlowDataset(data_root, npy_root, val_split, test_split, transform, mode)

--------------------------------------------------------------------------------
/Data/pfm_helper.py:
--------------------------------------------------------------------------------
import re
import numpy as np

import pdb

def read_pfm(file):

    file = open(file, 'rb')

    color = None
    width = None
    height = None
    scale = None
    endian = None

    header = file.readline().rstrip()
    header = str(header, 'utf-8')

    if header == 'PF':
        color = True
    elif header == 'Pf':
        color = False
    else:
        raise Exception('Not a PFM file.')

    dim_match = re.match(r'^(\d+)\s(\d+)\s$', str(file.readline(), 'utf-8'))

    if dim_match:
        width, height = map(int, dim_match.groups())
    else:
        raise Exception('Malformed PFM header.')

    scale = float(file.readline().rstrip())
    if scale < 0:  # little-endian
        endian = '<'
        scale = -scale
    else:
        endian = '>'  # big-endian

    data = np.fromfile(file, endian + 'f')
    shape = (height, width, 3) if color else (height, width)

    data = np.reshape(data, shape)
    data = np.flipud(data)
    data = data[np.newaxis, :, :].copy()
    return data, scale

--------------------------------------------------------------------------------
/Data/val.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/linjc16/ActiveStereoNet/5a366d7346d6ae5ca3420ec7f966fd0ac76e85b3/Data/val.npy

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 linjc16

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/Losses/__init__.py:
--------------------------------------------------------------------------------
from .supervise import *

def get_losses(name, **kwargs):

    name = name.lower()
    if name == 'rhloss':
        loss = RHLoss(**kwargs)
    elif name == 'xtloss':
        loss = XTLoss(**kwargs)
    else:
        raise NotImplementedError('Loss [{:s}] is not supported.'.format(name))

    return loss

--------------------------------------------------------------------------------
/Losses/supervise.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import pdb

class RHLoss(nn.Module):

    def __init__(self, max_disp):

        super(RHLoss, self).__init__()
        self.max_disp = max_disp
        self.crit = nn.SmoothL1Loss(reduction='mean')

    def forward(self, output, target):

        #mask = (target < self.max_disp).float()
        #output *= mask
        #target *= mask
        mask = target < self.max_disp
        mask.detach_()

        loss = self.crit(output[mask], target[mask])

        return loss


class XTLoss(nn.Module):
    '''
    Args:
        left_img, right_img : N * C * H * W
        dispmap             : N * H * W
    '''
    def __init__(self, max_disp):
        super(XTLoss, self).__init__()
        self.max_disp = max_disp
        self.theta = torch.Tensor(
            [[1, 0, 0],  # controls horizontal translation: negative shifts right, positive shifts left
             [0, 1, 0]]  # controls vertical translation: negative shifts down, positive shifts up
        )
        self.inplanes = 3
        self.outplanes = 3

    def forward(self, left_img, right_img, dispmap):

        n, c, h, w = left_img.shape

        #pdb.set_trace()
        theta = self.theta.repeat(left_img.size()[0], 1, 1)

        grid = F.affine_grid(theta, left_img.size())
        grid = grid.cuda()

        dispmap_norm = dispmap * 2 / w
        dispmap_norm = dispmap_norm.cuda()
        #pdb.set_trace()
        dispmap_norm = dispmap_norm.squeeze(1).unsqueeze(3)
        dispmap_norm = torch.cat((dispmap_norm, torch.zeros(dispmap_norm.size()).cuda()), dim=3)

        grid -= dispmap_norm

        recon_img = F.grid_sample(right_img, grid)

        #pdb.set_trace()
        recon_img_LCN, _, _ = self.LCN(recon_img, 9)

        left_img_LCN, _, left_std_local = self.LCN(left_img, 9)

        #pdb.set_trace()
        losses = torch.abs(((left_img_LCN - recon_img_LCN) * left_std_local))

        #pdb.set_trace()
        losses = self.ASW(left_img, losses, 12, 2)

        return losses

    def LCN(self, img, kSize):
        '''
        Args:
            img   : N * C * H * W
            kSize : 9 * 9
        '''

        w = torch.ones((self.outplanes, self.inplanes, kSize, kSize)).cuda() / (kSize * kSize)
        mean_local = F.conv2d(input=img, weight=w, padding=kSize // 2)

        mean_square_local = F.conv2d(input=img ** 2, weight=w, padding=kSize // 2)
        # NB: despite the name, std_local is an unbiased local *variance*; no square root is taken.
        std_local = (mean_square_local - mean_local ** 2) * (kSize ** 2) / (kSize ** 2 - 1)

        epsilon = 1e-6

        return (img - mean_local) / (std_local + epsilon), mean_local, std_local
    def ASW(self, img, Cost, kSize, sigma_omega):

        #pdb.set_trace()
        weightGraph = torch.zeros(img.shape, requires_grad=False).cuda()
        CostASW = torch.zeros(Cost.shape, dtype=torch.float, requires_grad=True).cuda()

        pad_len = kSize // 2
        img = F.pad(img, [pad_len] * 4)
        Cost = F.pad(Cost, [pad_len] * 4)
        n, c, h, w = img.shape
        #pdb.set_trace()

        for i in range(kSize):
            for j in range(kSize):
                tempGraph = torch.abs(img[:, :, pad_len : h - pad_len, pad_len : w - pad_len] - img[:, :, i:i + h - pad_len * 2, j:j + w - pad_len * 2])
                tempGraph = torch.exp(-tempGraph / sigma_omega)
                weightGraph += tempGraph
                CostASW += tempGraph * Cost[:, :, i:i + h - pad_len * 2, j:j + w - pad_len * 2]

        CostASW = CostASW / weightGraph

        return CostASW.mean()

--------------------------------------------------------------------------------
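Note: unlike RHLoss, XTLoss is self-supervised and never sees ground-truth disparity. It warps the right image to the left view with the predicted disparity, compares the two under local contrast normalization, and smooths the per-pixel cost with adaptive support weights. A minimal sketch of how the solver invokes it (the shapes are assumptions; CUDA is required by the implementation above):

```python
import torch
from Losses import get_losses

crit = get_losses('xtloss', max_disp=144)
left = torch.rand(2, 3, 128, 256).cuda()             # hypothetical 128x256 crops
right = torch.rand(2, 3, 128, 256).cuda()
disp_pred = torch.rand(2, 1, 128, 256).cuda() * 144  # stands in for the network output
loss = crit(left, right, disp_pred)                  # scalar reconstruction loss
```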
/Metrics/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/linjc16/ActiveStereoNet/5a366d7346d6ae5ca3420ec7f966fd0ac76e85b3/Metrics/__init__.py

--------------------------------------------------------------------------------
/Metrics/metrics.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pdb

def epe_metric(target, output, maxdisp):
    # NB: the in-place masking below mutates the caller's tensors.
    mask = (target < maxdisp).float()

    target *= mask
    output *= mask

    return torch.abs(target - output).mean()

#def epe_metric(target, output, maxdisp):
#    mask = target < maxdisp
#    return torch.abs(output[mask] - target[mask]).mean()

#def tripe_metric(target, output, maxdisp):
#    mask = (target < maxdisp).float()
#
#    target *= mask
#    output *= mask
#
#    delta = torch.abs(target - output)
#    gt3 = (delta > 3.0).float()
#    eps = 1e-7
#    return gt3.sum() / (delta.numel() + eps)

def tripe_metric(target, output, maxdisp):
    #pdb.set_trace()
    delta = torch.abs(target - output)
    correct = ((delta < 3) | torch.lt(delta, target * 0.05))
    eps = 1e-7
    return 1 - (float(torch.sum(correct)) / (delta.numel() + eps))

--------------------------------------------------------------------------------
/Models/ActiveStereoNet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
import pdb
from .blocks import *

class SiameseTower(nn.Module):
    def __init__(self, scale_factor):
        super(SiameseTower, self).__init__()

        self.conv1 = conv_block(nc_in=3, nc_out=32, k=3, s=1, norm=None, act=None)
        # NB: [block] * 3 repeats the same module instance, so these residual blocks share weights.
        res_blocks = [ResBlock(32, 32, 3, 1, 1)] * 3
        self.res_blocks = nn.Sequential(*res_blocks)
        convblocks = [conv_block(32, 32, k=3, s=2, norm='bn', act='lrelu')] * int(math.log2(scale_factor))
        self.conv_blocks = nn.Sequential(*convblocks)
        self.conv2 = conv_block(nc_in=32, nc_out=32, k=3, s=1, norm=None, act=None)

    def forward(self, x):

        #pdb.set_trace()
        out = self.conv1(x)
        out = self.res_blocks(out)
        out = self.conv_blocks(out)
        out = self.conv2(out)

        return out

class CoarseNet(nn.Module):
    def __init__(self, maxdisp, scale_factor, img_shape):
        super(CoarseNet, self).__init__()
        self.maxdisp = maxdisp
        self.scale_factor = scale_factor
        self.img_shape = img_shape

        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

        self.conv3d_1 = conv3d_block(64, 32, 3, 1, norm='bn', act='lrelu')
        self.conv3d_2 = conv3d_block(32, 32, 3, 1, norm='bn', act='lrelu')
        self.conv3d_3 = conv3d_block(32, 32, 3, 1, norm='bn', act='lrelu')
        self.conv3d_4 = conv3d_block(32, 32, 3, 1, norm='bn', act='lrelu')

        self.conv3d_5 = conv3d_block(32, 1, 3, 1, norm=None, act=None)
        self.disp_reg = DisparityRegression(self.maxdisp)

    def costVolume(self, refimg_fea, targetimg_fea, views):
        # Cost volume: concatenate reference and shifted target features per disparity hypothesis.
        cost = torch.zeros(refimg_fea.size()[0], refimg_fea.size()[1] * 2, self.maxdisp // self.scale_factor, refimg_fea.size()[2], refimg_fea.size()[3]).cuda()
        views = views.lower()
        if views == 'left':
            for i in range(self.maxdisp // self.scale_factor):
                if i > 0:
                    cost[:, :refimg_fea.size()[1], i, :, i:] = refimg_fea[:, :, :, i:]
                    cost[:, refimg_fea.size()[1]:, i, :, i:] = targetimg_fea[:, :, :, :-i]
                else:
                    cost[:, :refimg_fea.size()[1], i, :, :] = refimg_fea
                    cost[:, refimg_fea.size()[1]:, i, :, :] = targetimg_fea
        elif views == 'right':
            for i in range(self.maxdisp // self.scale_factor):
                if i > 0:
                    cost[:, :refimg_fea.size()[1], i, :, :-i] = refimg_fea[:, :, :, i:]
                    cost[:, refimg_fea.size()[1]:, i, :, :-i] = targetimg_fea[:, :, :, :-i]
                else:
                    cost[:, :refimg_fea.size()[1], i, :, :] = refimg_fea
                    cost[:, refimg_fea.size()[1]:, i, :, :] = targetimg_fea
        return cost

    def Coarsepred(self, cost):
        #pdb.set_trace()
        cost = self.conv3d_1(cost)
        cost = self.conv3d_2(cost) + cost
        cost = self.conv3d_3(cost) + cost
        cost = self.conv3d_4(cost) + cost

        cost = self.conv3d_5(cost)
        #pdb.set_trace()
        cost = F.interpolate(cost, size=[self.maxdisp, self.img_shape[1], self.img_shape[0]], mode='trilinear', align_corners=False)
        #pdb.set_trace()
        pred = cost.softmax(dim=2).squeeze(dim=1)
        pred = self.disp_reg(pred)

        return pred

    def forward(self, refimg_fea, targetimg_fea):
        '''
        Args:
            refimg_fea: output of SiameseTower for a left image
            targetimg_fea: output of SiameseTower for the right image
        '''
        cost_left = self.costVolume(refimg_fea, targetimg_fea, 'left')
        #cost_right = self.costVolume(refimg_fea, targetimg_fea, 'right')

        pred_left = self.Coarsepred(cost_left)
        #pred_right = self.Coarsepred(cost_right)

        return pred_left  #, pred_right
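# Note on shapes: costVolume builds a concatenation-style volume of size
# N x 2C x (maxdisp / scale_factor) x H/s x W/s; e.g. with the defaults
# (max_disp=144, scale_factor=8) and 32-channel tower features of 16x32,
# costVolume(fL, fR, 'left') is (N, 64, 18, 16, 32). Coarsepred then
# upsamples the filtered volume to full resolution and collapses the
# disparity axis with the soft-argmin in DisparityRegression.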
class RefineNet(nn.Module):
    def __init__(self):
        super(RefineNet, self).__init__()
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

        # stream_1, left_img
        self.conv1_s1 = conv_block(3, 16, 3, 1)
        self.resblock1_s1 = ResBlock(16, 16, 3, 1, 1)
        self.resblock2_s1 = ResBlock(16, 16, 3, 1, 2)

        # stream_2, upsampled low-resolution disp
        self.conv1_s2 = conv_block(1, 16, 1, 1)
        self.resblock1_s2 = ResBlock(16, 16, 3, 1, 1)
        self.resblock2_s2 = ResBlock(16, 16, 3, 1, 2)

        # cat
        self.resblock3 = ResBlock(32, 32, 3, 1, 4)
        self.resblock4 = ResBlock(32, 32, 3, 1, 8)
        self.resblock5 = ResBlock(32, 32, 3, 1, 1)
        self.resblock6 = ResBlock(32, 32, 3, 1, 1)
        self.conv2 = conv_block(32, 1, 3, 1)

    def forward(self, left_img, up_disp):

        stream1 = self.conv1_s1(left_img)
        stream1 = self.resblock1_s1(stream1)
        stream1 = self.resblock2_s1(stream1)

        stream2 = self.conv1_s2(up_disp)
        stream2 = self.resblock1_s2(stream2)
        stream2 = self.resblock2_s2(stream2)

        out = torch.cat((stream1, stream2), 1)
        out = self.resblock3(out)
        out = self.resblock4(out)
        out = self.resblock5(out)
        out = self.resblock6(out)
        out = self.conv2(out)

        return out


class InvalidationNet(nn.Module):
    def __init__(self):
        super(InvalidationNet, self).__init__()
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

        resblocks1 = [ResBlock(64, 64, 3, 1, 1)] * 5
        self.resblocks1 = nn.Sequential(*resblocks1)
        self.conv1 = conv_block(64, 1, 3, 1, norm=None, act=None)

        self.conv2 = conv_block(5, 32, 3, 1)
        resblocks2 = [ResBlock(32, 32, 3, 1, 1)] * 4
        self.resblocks2 = nn.Sequential(*resblocks2)
        self.conv3 = conv_block(32, 1, 3, 1, norm=None, act=None)

    def forward(self, left_tower, right_tower, left_img, freso_disp):

        features = torch.cat((left_tower, right_tower), 1)
        out1 = self.resblocks1(features)
        out1 = self.conv1(out1)

        input = torch.cat((left_img, out1, freso_disp), 1)

        out2 = self.conv2(input)
        out2 = self.resblocks2(out2)
        out2 = self.conv3(out2)

        return out2


class ActiveStereoNet(nn.Module):
    def __init__(self, maxdisp, scale_factor, img_shape):
        super(ActiveStereoNet, self).__init__()
        self.maxdisp = maxdisp
        self.scale_factor = scale_factor
        self.SiameseTower = SiameseTower(scale_factor)
        self.CoarseNet = CoarseNet(maxdisp, scale_factor, img_shape)
        self.RefineNet = RefineNet()
        #self.InvalidationNet = InvalidationNet()
        self.img_shape = img_shape

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.Conv3d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm3d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()

    def forward(self, left, right):

        #pdb.set_trace()
        left_tower = self.SiameseTower(left)
        right_tower = self.SiameseTower(right)
        #pdb.set_trace()
        coarseup_pred = self.CoarseNet(left_tower, right_tower)
        res_disp = self.RefineNet(left, coarseup_pred)

        ref_pred = coarseup_pred + res_disp

        return nn.ReLU(False)(ref_pred)

--------------------------------------------------------------------------------
/Models/StereoNet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
import pdb
from .blocks import *
from .ActiveStereoNet import *

--------------------------------------------------------------------------------
/Models/__init__.py:
--------------------------------------------------------------------------------
from .ActiveStereoNet import ActiveStereoNet

def get_model(config):

    cfg_net = config['model']
    net_name = cfg_net['which_model'].lower()
    if net_name == 'activestereonet':
        max_disp = cfg_net['max_disp']
        scale_factor = cfg_net['scale_factor']
        img_shape = config['data']['crop_size']
        model = ActiveStereoNet(max_disp, scale_factor, img_shape)
    else:
        raise NotImplementedError('Model [{:s}] is not supported.'.format(net_name))

    return model

--------------------------------------------------------------------------------
/Models/blocks.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pdb


def conv_block(nc_in, nc_out, k, s, norm='bn', act='lrelu', dilation=1):

    blocks = [
        nn.Conv2d(nc_in, nc_out, k, s, dilation if dilation > 1 else k // 2, dilation=dilation)
    ]
    # NOTE: normalization is currently disabled; the 'norm' argument is accepted but ignored.
    '''
    if norm is not None:
        norm = norm.lower()
        if norm == 'bn':
            blocks.append(nn.BatchNorm2d(nc_out))
        elif norm == 'in':
            blocks.append(nn.InstanceNorm2d(nc_out))
        else:
            raise RuntimeError
    '''

    if act is not None:
        act = act.lower()
        if act == 'relu':
            blocks.append(nn.ReLU(True))
        elif act == 'lrelu':
            blocks.append(nn.LeakyReLU(0.2, True))
        else:
            raise RuntimeError

    return nn.Sequential(*blocks)


def conv3d_block(in_planes, out_planes, kernel_size, stride, norm='bn', act='lrelu'):

    blocks = [
        nn.Conv3d(in_planes, out_planes, kernel_size, stride, kernel_size // 2)
    ]
    # NOTE: as in conv_block, the normalization branch is commented out.
    '''
    if norm is not None:
        norm = norm.lower()
        if norm == 'bn':
            blocks.append(nn.BatchNorm3d(out_planes))
        elif norm == 'in':
            blocks.append(nn.InstanceNorm3d(out_planes))
        else:
            raise RuntimeError
    '''

    if act is not None:
        act = act.lower()
        if act == 'lrelu':
            blocks.append(nn.LeakyReLU(0.2, True))
        elif act == 'relu':
            blocks.append(nn.ReLU(True))
        else:
            raise RuntimeError

    return nn.Sequential(*blocks)
class ResBlock(nn.Module):

    def __init__(self, in_planes, out_planes, kernel_size, stride, dilation=1):
        super(ResBlock, self).__init__()
        self.conv = conv_block(in_planes, out_planes, kernel_size, stride, norm='bn', act='lrelu', dilation=dilation)

    def forward(self, x):
        out = self.conv(x)
        out = out + x
        return out


class DisparityRegression(nn.Module):
    def __init__(self, maxdisp):
        super(DisparityRegression, self).__init__()

        self.disp = torch.from_numpy(
            np.reshape(np.array(range(maxdisp)),
                       [1, maxdisp, 1, 1]
                       )).cuda().float().requires_grad_(False)

    def forward(self, x):

        y = x.mul(self.disp).sum(dim=1, keepdim=True)

        return y

--------------------------------------------------------------------------------
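Note: DisparityRegression is the soft-argmin readout used by CoarseNet. Given softmax weights p over the maxdisp hypotheses, it returns the expectation sum_d d * p_d, which keeps the disparity estimate differentiable. A toy check (sizes are illustrative; CUDA is required since the class allocates its index tensor on the GPU):

```python
import torch
from Models.blocks import DisparityRegression

reg = DisparityRegression(maxdisp=4)
p = torch.zeros(1, 4, 2, 2).cuda()
p[:, 3] = 1.0     # put all probability mass on disparity 3
print(reg(p))     # -> a (1, 1, 2, 2) tensor filled with 3.0
```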
/Options/__init__.py:
--------------------------------------------------------------------------------
import json
import os

from collections import OrderedDict
from datetime import datetime

def get_timestamp():
    return datetime.now().strftime('%Y%m%d')

def parse_opt(opt_path):
    # remove comments starting with '//'
    json_str = ''
    with open(opt_path, 'r') as f:
        for line in f:
            line = line.split('//')[0] + '\n'
            json_str += line
    opt = json.loads(json_str, object_pairs_hook=OrderedDict)

    opt['timestamp'] = get_timestamp()

    return opt

--------------------------------------------------------------------------------
/Options/example.json:
--------------------------------------------------------------------------------
{
    "mode": "train",
    "deterministic": true,
    "gpu_ids": "0",
    "cpu_threads": "4",
    "dataset_name": "SceneFlow",
    "imshow": false,

    "data": {
        "data_root": "I:/dataset/",
        "npy_root": "./Data",
        "test_split": 0.1,
        "val_split": 0.2,
        "crop_size": [960, 540]
    },

    "model": {
        "which_model": "ActiveStereoNet",
        "max_disp": 144,
        "scale_factor": 8,
        "loss": "XTLoss"
    },

    "solver": {
        "batch_size": 2,
        "optimizer_type": "RMSProp",
        "lr_init": 1e-3,
        "gamma": 0.5,
        "milestones": [20000, 30000, 40000, 50000],
        "eval_steps": 2000,
        "save_steps": 2000,
        "max_steps": 60000,
        "exp_prefix": "Experiments",
        "resume_iter": 32000,
        "model_name": "191027"
    }
}

--------------------------------------------------------------------------------
/Sovlers/__init__.py:
--------------------------------------------------------------------------------
from .solver_train import TrainSolver
from .solver_test import TestSolver

def get_solver(config):

    mode_cfg = config['mode'].lower()
    if mode_cfg == 'train':
        solver = TrainSolver(config)
    elif mode_cfg == 'test':
        solver = TestSolver(config)
    else:
        raise NotImplementedError('Solver [{:s}] is not supported.'.format(mode_cfg))

    return solver

--------------------------------------------------------------------------------
/Sovlers/solver_test.py:
--------------------------------------------------------------------------------
import torch
import os
import time
import cv2
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.optim as optim
import torch.nn.functional as F
from Data import get_loader
from Models import get_model
from Losses import get_losses
from Metrics.metrics import epe_metric
from Metrics.metrics import tripe_metric
import pdb


class TestSolver(object):
    def __init__(self, config):
        self.config = config
        self.cfg_solver = config['solver']
        self.cfg_dataset = config['data']
        self.cfg_model = config['model']

        self.max_disp = self.cfg_model['max_disp']
        self.model = get_model(self.config)
        self.test_loader = get_loader(self.config)
        self.imshow = config['imshow']

    def load_checkpoint(self):
        ckpt_root = os.path.join(self.cfg_solver['exp_prefix'], self.cfg_solver['model_name'], 'models')
        ckpt_name = 'iter_{:d}.pth'.format(self.cfg_solver['resume_iter'])
        ckpt_full = os.path.join(ckpt_root, ckpt_name)
        states = torch.load(ckpt_full, map_location=lambda storage, loc: storage)

        self.model.load_state_dict(states['model_state'])

    def save_results(self, output, target):

        for i in range(output.shape[0]):

            outcmap = output[i]
            tarcmap = target[i]
            #pdb.set_trace()
            outcmap = outcmap.cpu().numpy().astype(np.uint8).squeeze()
            tarcmap = tarcmap.cpu().numpy().astype(np.uint8).squeeze()

            #pdb.set_trace()
            outcmap = cv2.applyColorMap(outcmap, cv2.COLORMAP_RAINBOW)
            tarcmap = cv2.applyColorMap(tarcmap, cv2.COLORMAP_RAINBOW)

            plt.figure(figsize=(12, 6))  # figsize is in inches; the original (640, 840) was far too large
            plt.subplot(1, 2, 1)
            plt.imshow(tarcmap)
            plt.axis('off')
            plt.title('G.T.')

            plt.subplot(1, 2, 2)
            plt.imshow(outcmap)
            plt.axis('off')
            plt.title('Prediction')

            plt.show()
            #pdb.set_trace()

    def run(self):
        self.model = nn.DataParallel(self.model)
        self.model.cuda()

        if self.cfg_solver['resume_iter'] > 0:
            self.load_checkpoint()
            print('Model loaded.')

        self.model.eval()

        start_time = time.time()
        with torch.no_grad():
            EPE_metric = 0.0
            TriPE_metric = 0.0
            N_total = 0.0
            #pdb.set_trace()
            for test_batch in self.test_loader:
                imgL, imgR, disp_L, _ = test_batch
                imgL, imgR, disp_L = imgL.cuda(), imgR.cuda(), disp_L.cuda()

                N_curr = imgL.shape[0]

                disp_pred = self.model(imgL, imgR)

                EPE_metric += epe_metric(disp_L, disp_pred, self.max_disp) * N_curr
                TriPE_metric += tripe_metric(disp_L, disp_pred, self.max_disp) * N_curr
                if self.imshow:
                    self.save_results(disp_pred, disp_L)

                N_total += N_curr

            EPE_metric /= N_total
            TriPE_metric /= N_total

            elapsed = time.time() - start_time
            print(
                'Test: EPE = {:.6f} px, 3PE = {:.3f} %, time = {:.3f} s.'.format(
                    EPE_metric, TriPE_metric * 100, elapsed / N_total
                )
            )

--------------------------------------------------------------------------------
/Sovlers/solver_train.py:
--------------------------------------------------------------------------------
import torch
import os
import time
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from Data import get_loader
from Models import get_model
from Losses import get_losses
from Metrics.metrics import epe_metric
from Metrics.metrics import tripe_metric
import pdb

class TrainSolver(object):

    def __init__(self, config):

        self.config = config
        self.cfg_solver = config['solver']
        self.cfg_dataset = config['data']
        self.cfg_model = config['model']
        self.reloaded = self.cfg_solver['resume_iter'] > 0

        self.max_disp = self.cfg_model['max_disp']
        self.loss_name = self.cfg_model['loss']
        self.train_loader, self.val_loader = get_loader(self.config)
        self.model = get_model(self.config)

        self.crit = get_losses(self.loss_name, max_disp=self.max_disp)

        if self.cfg_solver['optimizer_type'].lower() == 'rmsprop':
            self.optimizer = optim.RMSprop(self.model.parameters(), lr=self.cfg_solver['lr_init'])
        elif self.cfg_solver['optimizer_type'].lower() == 'adam':
            self.optimizer = optim.Adam(self.model.parameters(), lr=self.cfg_solver['lr_init'])
        else:
            raise NotImplementedError('Optimizer type [{:s}] is not supported.'.format(self.cfg_solver['optimizer_type']))
        self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=self.cfg_solver['milestones'], gamma=self.cfg_solver['gamma'])
        self.global_step = 1

    def save_checkpoint(self):

        ckpt_root = os.path.join(self.cfg_solver['exp_prefix'], self.cfg_solver['model_name'], 'models')

        if not os.path.exists(ckpt_root):
            os.makedirs(ckpt_root)

        ckpt_name = 'iter_{:d}.pth'.format(self.global_step)
        states = {
            'global_step': self.global_step,
            'model_state': self.model.state_dict(),
            'optimizer_state': self.optimizer.state_dict(),
            'scheduler_state': self.scheduler.state_dict()
        }
        ckpt_full = os.path.join(ckpt_root, ckpt_name)

        torch.save(states, ckpt_full)

    def load_checkpoint(self):

        ckpt_root = os.path.join(self.cfg_solver['exp_prefix'], self.cfg_solver['model_name'], 'models')

        ckpt_name = 'iter_{:d}.pth'.format(self.cfg_solver['resume_iter'])

        ckpt_full = os.path.join(ckpt_root, ckpt_name)

        states = torch.load(ckpt_full, map_location=lambda storage, loc: storage)

        self.global_step = states['global_step']
        self.model.load_state_dict(states['model_state'])
        self.optimizer.load_state_dict(states['optimizer_state'])
        self.scheduler.load_state_dict(states['scheduler_state'])

    def run(self):
        self.model = nn.DataParallel(self.model)
        self.model.cuda()

        print('Number of model parameters: {}'.format(sum([p.data.nelement() for p in self.model.parameters()])))

        if self.cfg_solver['resume_iter'] > 0:
            self.load_checkpoint()
            print('[{:d}] Model loaded.'.format(self.global_step))

        data_iter = iter(self.train_loader)
        while True:
            try:
                data_batch = next(data_iter)
            except StopIteration:
                data_iter = iter(self.train_loader)
                data_batch = next(data_iter)

            if self.global_step > self.cfg_solver['max_steps']:
                break

            start_time = time.time()

            self.model.train()
            imgL, imgR, disp_L, _ = data_batch
            imgL, imgR, disp_L = imgL.cuda(), imgR.cuda(), disp_L.cuda()

            self.optimizer.zero_grad()
            #pdb.set_trace()
            disp_pred_left = self.model(imgL, imgR)

            #pdb.set_trace()

            loss = self.crit(imgL, imgR, disp_pred_left)
            loss.backward()
            self.optimizer.step()

            elapsed = time.time() - start_time
            train_EPE_left = epe_metric(disp_L, disp_pred_left, self.max_disp)
            train_3PE_left = tripe_metric(disp_L, disp_pred_left, self.max_disp)

            print(
                '[{:d}/{:d}] Train Loss = {:.6f}, EPE = {:.3f} px, 3PE = {:.3f}%, time = {:.3f}s.'.format(
                    self.global_step, self.cfg_solver['max_steps'],
                    loss.item(),
                    train_EPE_left,
                    train_3PE_left * 100,
                    elapsed
                ), end='\r'
            )
            self.scheduler.step()

            if self.global_step % self.cfg_solver['save_steps'] == 0 and not self.reloaded:
                self.save_checkpoint()
                print('')
                print('[{:d}] Model saved.'.format(self.global_step))

            if self.global_step % self.cfg_solver['eval_steps'] == 0 and not self.reloaded:
                start_time = time.time()
                self.model.eval()
                with torch.no_grad():

                    val_EPE_metric_left = 0.0
                    val_TriPE_metric_left = 0.0
                    N_total = 0.0

                    for val_batch in self.val_loader:
                        imgL, imgR, disp_L, _ = val_batch
                        imgL, imgR, disp_L = imgL.cuda(), imgR.cuda(), disp_L.cuda()

                        N_curr = imgL.shape[0]

                        disp_pred_left = self.model(imgL, imgR)

                        val_EPE_metric_left += epe_metric(disp_L, disp_pred_left, self.max_disp) * N_curr
                        val_TriPE_metric_left += tripe_metric(disp_L, disp_pred_left, self.max_disp) * N_curr

                        N_total += N_curr

                    val_EPE_metric_left /= N_total
                    val_TriPE_metric_left /= N_total

                    elapsed = time.time() - start_time
                    print(
                        '[{:d}/{:d}] Validation: EPE = {:.6f} px, 3PE = {:.3f} %, time = {:.3f} s.'.format(
                            self.global_step, self.cfg_solver['max_steps'],
                            val_EPE_metric_left,
                            val_TriPE_metric_left * 100,
                            elapsed / N_total
                        )
                    )

            self.global_step += 1

            self.reloaded = False

--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import os
import argparse
import torch
import numpy as np
import random
from torch.backends import cudnn

from Options import parse_opt
from Sovlers import get_solver

def main():

    # Parse arguments.
    parser = argparse.ArgumentParser()
    parser.add_argument('--options', type=str, help='Path to the option JSON file.', default='./Options/example.json')
    args = parser.parse_args()
    opt = parse_opt(args.options)

    # GPU/CPU specification.
    os.environ['CUDA_VISIBLE_DEVICES'] = opt['gpu_ids']
    os.environ['MKL_NUM_THREADS'] = opt['cpu_threads']
    os.environ['NUMEXPR_NUM_THREADS'] = opt['cpu_threads']
    os.environ['OMP_NUM_THREADS'] = opt['cpu_threads']

    # Deterministic settings.
    if opt['deterministic']:
        torch.manual_seed(712)
        np.random.seed(712)
        random.seed(712)
        cudnn.deterministic = True
        cudnn.benchmark = False
    else:
        cudnn.deterministic = False
        cudnn.benchmark = True

    # Create solver.
    solver = get_solver(opt)

    # Run.
    solver.run()

if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
# ActiveStereoNet (PyTorch Implementation)

#### [Paper] [ActiveStereoNet: End-to-End Self-Supervised Learning for Active Stereo Systems](https://arxiv.org/abs/1807.06009)

This repository provides an unofficial PyTorch implementation of ActiveStereoNet.
## Dependencies and Installation

- Python 3 (Anaconda is recommended)

- PyTorch 1.1.0

- NVIDIA GPU + CUDA

## Dataset Preparation

The SceneFlow loader in `Data/SceneFlow_helper.py` expects `data_root` (set in `Options/example.json`) to contain the SceneFlow subsets side by side: the FlyingThings3D `frames_cleanpass` and `frames_disparity` folders, plus the Monkaa and Driving releases (directory names containing `frames_cleanpass` and `disparity`, respectively). Point `data_root` and `npy_root` in the option file at your local layout.

## Usage

Training and testing are both driven by `main.py` together with an option JSON file; set `"mode"` to `"train"` or `"test"` (and `"resume_iter"` to a saved step to load a checkpoint). A minimal invocation is sketched below.
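The option values themselves are placeholders to adapt; the default `--options` path comes from `main.py`:

```
python main.py --options ./Options/example.json
```

--------------------------------------------------------------------------------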