├── Images
├── diffuvolume.png
├── infer.png
└── zero.png
├── KITTI12
├── LICENSE
├── datasets
│   ├── MiddleburyLoader.py
│   ├── __init__.py
│   ├── __init__.pyc
│   ├── __pycache__
│   │   ├── MiddleburyLoader.cpython-37.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   ├── __init__.cpython-38.pyc
│   │   ├── data_io.cpython-37.pyc
│   │   ├── data_io.cpython-38.pyc
│   │   ├── eth3dLoader.cpython-37.pyc
│   │   ├── flow_transforms.cpython-37.pyc
│   │   ├── flow_transforms.cpython-38.pyc
│   │   ├── kitti_dataset.cpython-37.pyc
│   │   ├── kitti_dataset.cpython-38.pyc
│   │   ├── listfiles.cpython-37.pyc
│   │   ├── readpfm.cpython-37.pyc
│   │   ├── sceneflow_dataset.cpython-37.pyc
│   │   └── sceneflow_dataset.cpython-38.pyc
│   ├── data_io.py
│   ├── data_io.pyc
│   ├── eth3dLoader.py
│   ├── flow_transforms.py
│   ├── kitti_dataset.py
│   ├── kitti_dataset.pyc
│   ├── kitti_dataset_small.py
│   ├── listfiles.py
│   ├── readpfm.py
│   ├── sceneflow_dataset.py
│   └── sceneflow_dataset.pyc
├── filenames
│   ├── kitti12_all.txt
│   ├── kitti12_test.txt
│   ├── kitti12_train.txt
│   └── kitti12_val.txt
├── main.py
├── models
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   ├── __init__.cpython-38.pyc
│   │   ├── gwcnet.cpython-37.pyc
│   │   ├── head.cpython-38.pyc
│   │   ├── loss.cpython-37.pyc
│   │   ├── loss.cpython-38.pyc
│   │   ├── pwcnet.cpython-37.pyc
│   │   ├── pwcnet.cpython-38.pyc
│   │   ├── pwcnet_ddim.cpython-38.pyc
│   │   ├── submodule.cpython-37.pyc
│   │   └── submodule.cpython-38.pyc
│   ├── head.py
│   ├── loss.py
│   ├── pwcnet.py
│   ├── pwcnet_ddim.py
│   ├── relu
│   │   ├── pwcnet.py
│   │   └── submodule.py
│   └── submodule.py
├── save_disp_sceneflow_kitti12.py
├── scripts
│   └── kitti12.sh
├── test.py
└── utils
│   ├── __init__.py
│   ├── __init__.pyc
│   ├── __pycache__
│   ├── __init__.cpython-37.pyc
│   ├── __init__.cpython-38.pyc
│   ├── experiment.cpython-37.pyc
│   ├── experiment.cpython-38.pyc
│   ├── metrics.cpython-37.pyc
│   ├── metrics.cpython-38.pyc
│   ├── visualization.cpython-37.pyc
│   └── visualization.cpython-38.pyc
│   ├── experiment.py
│   ├── experiment.pyc
│   ├── metrics.py
│   ├── metrics.pyc
│   ├── visualization.py
│   └── visualization.pyc
├── KITTI15
├── core
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   ├── __init__.cpython-38.pyc
│   │   ├── extractor.cpython-37.pyc
│   │   ├── extractor.cpython-38.pyc
│   │   ├── geometry.cpython-37.pyc
│   │   ├── geometry.cpython-38.pyc
│   │   ├── geometry_ddim.cpython-37.pyc
│   │   ├── geometry_ddim.cpython-38.pyc
│   │   ├── head.cpython-37.pyc
│   │   ├── head.cpython-38.pyc
│   │   ├── igev_stereo.cpython-37.pyc
│   │   ├── igev_stereo.cpython-38.pyc
│   │   ├── igev_stereo_ddim.cpython-37.pyc
│   │   ├── igev_stereo_ddim.cpython-38.pyc
│   │   ├── stereo_datasets.cpython-37.pyc
│   │   ├── stereo_datasets.cpython-38.pyc
│   │   ├── submodule.cpython-37.pyc
│   │   ├── submodule.cpython-38.pyc
│   │   ├── update.cpython-37.pyc
│   │   └── update.cpython-38.pyc
│   ├── extractor.py
│   ├── geometry.py
│   ├── geometry_ddim.py
│   ├── head.py
│   ├── igev_stereo.py
│   ├── igev_stereo_ddim.py
│   ├── stereo_datasets.py
│   ├── submodule.py
│   ├── update.py
│   └── utils
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   ├── __init__.cpython-38.pyc
│   │   ├── augmentor.cpython-37.pyc
│   │   ├── augmentor.cpython-38.pyc
│   │   ├── frame_utils.cpython-37.pyc
│   │   ├── frame_utils.cpython-38.pyc
│   │   ├── utils.cpython-37.pyc
│   │   └── utils.cpython-38.pyc
│   │   ├── augmentor.py
│   │   ├── frame_utils.py
│   │   └── utils.py
├── evaluate_stereo.py
├── evaluate_stereo_origin.py
├── run.sh
├── save_disp.py
└── train_stereo.py
├── LICENSE.txt
├── README.md
└── SceneFlow
├── LICENSE
├── datasets
├── __init__.py
├── __pycache__
│   ├── __init__.cpython-38.pyc
│   ├── data_io.cpython-38.pyc
│   ├── flow_transforms.cpython-38.pyc
│   ├── kitti_dataset.cpython-38.pyc
│   ├── kitti_dataset_1215.cpython-38.pyc
│   └── sceneflow_dataset.cpython-38.pyc
├── data_io.py
├── flow_transforms.py
├── kitti_dataset.py
├── kitti_dataset_1215.py
└── sceneflow_dataset.py
├── filenames
├── sceneflow_test.txt
├── sceneflow_test_spe.txt
├── sceneflow_train.txt
└── train_scene_flow.txt
├── main.py
├── models
├── __init__.py
├── __pycache__
│   ├── __init__.cpython-38.pyc
│   ├── acv.cpython-38.pyc
│   ├── acv_ddim.cpython-38.pyc
│   ├── acv_ddim_lowD.cpython-38.pyc
│   ├── acv_ddpm.cpython-38.pyc
│   ├── head.cpython-38.pyc
│   ├── loss.cpython-38.pyc
│   ├── pwcnet.cpython-38.pyc
│   └── submodule.cpython-38.pyc
├── acv.py
├── acv_ddim.py
├── head.py
├── loss.py
├── submodule.py
└── temp.py
├── save_disp_sceneflow.py
├── submodule.py
├── test_sceneflow_ddim.py
└── utils
├── __init__.py
├── __pycache__
├── __init__.cpython-38.pyc
├── experiment.cpython-38.pyc
├── metrics.cpython-38.pyc
├── misc.cpython-38.pyc
└── visualization.cpython-38.pyc
├── experiment.py
├── metrics.py
├── misc.py
└── visualization.py

/Images/diffuvolume.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/Images/diffuvolume.png
--------------------------------------------------------------------------------
/Images/infer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/Images/infer.png
--------------------------------------------------------------------------------
/Images/zero.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/Images/zero.png
--------------------------------------------------------------------------------
/KITTI12/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 | 
 3 | Copyright (c) 2019 Xiaoyang Guo, Kai Yang, Wukui Yang, Xiaogang Wang, Hongsheng Li
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /KITTI12/datasets/MiddleburyLoader.py: -------------------------------------------------------------------------------- 1 | import os, torch, torch.utils.data as data 2 | from PIL import Image 3 | import numpy as np 4 | from . import flow_transforms 5 | import pdb 6 | import torchvision 7 | import warnings 8 | from . import readpfm as rp 9 | from datasets.data_io import get_transform, read_all_lines 10 | warnings.filterwarnings('ignore', '.*output shape of zoom.*') 11 | import cv2 12 | 13 | IMG_EXTENSIONS = [ 14 | '.jpg', '.JPG', '.jpeg', '.JPEG', 15 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP'] 16 | 17 | def is_image_file(filename): 18 | return any((filename.endswith(extension) for extension in IMG_EXTENSIONS)) 19 | 20 | 21 | def default_loader(path): 22 | return Image.open(path).convert('RGB') 23 | 24 | 25 | def disparity_loader(path): 26 | if '.png' in path: 27 | data = Image.open(path) 28 | data = np.ascontiguousarray(data,dtype=np.float32)/256 29 | return data 30 | else: 31 | data = rp.readPFM(path)[0] 32 | data = np.ascontiguousarray(data, dtype=np.float32) 33 | return data 34 | 35 | 36 | class myImageFloder(data.Dataset): 37 | 38 | def __init__(self, left, right, left_disparity, training, right_disparity=None, loader=default_loader, dploader=disparity_loader): 39 | self.left = left 40 | self.right = right 41 | self.disp_L = left_disparity 42 | self.disp_R = right_disparity 43 | self.training = training 44 | self.loader = loader 45 | self.dploader = dploader 46 | self.order = 0 47 | 48 | def __getitem__(self, index): 49 | left = self.left[index] 50 | right = self.right[index] 51 | left_img = self.loader(left) 52 | right_img = self.loader(right) 53 | if self.disp_L is not None: 54 | disp_L = self.disp_L[index] 55 | disparity = self.dploader(disp_L) 56 | disparity[disparity == np.inf] = 0 57 | else: 58 | disparity = None 59 | 60 | if self.training: 61 | th, tw = 256, 512 62 | #th, tw = 320, 704 63 | random_brightness = np.random.uniform(0.5, 2.0, 2) 64 | random_gamma = np.random.uniform(0.8, 1.2, 2) 65 | random_contrast = np.random.uniform(0.8, 1.2, 2) 66 | left_img = torchvision.transforms.functional.adjust_brightness(left_img, random_brightness[0]) 67 | left_img = torchvision.transforms.functional.adjust_gamma(left_img, random_gamma[0]) 68 | left_img = torchvision.transforms.functional.adjust_contrast(left_img, random_contrast[0]) 69 | right_img = torchvision.transforms.functional.adjust_brightness(right_img, random_brightness[1]) 70 | right_img = torchvision.transforms.functional.adjust_gamma(right_img, random_gamma[1]) 71 | right_img = torchvision.transforms.functional.adjust_contrast(right_img, random_contrast[1]) 72 | right_img = np.asarray(right_img) 73 | left_img = np.asarray(left_img) 74 | 75 | # w, h = left_img.size 76 | # th, tw = 256, 512 77 | # 78 | # x1 = random.randint(0, w - tw) 79 | # y1 = random.randint(0, h - th) 80 | # 81 | # left_img = left_img.crop((x1, y1, x1 + tw, y1 + th)) 82 | # right_img = right_img.crop((x1, y1, x1 + tw, y1 + th)) 83 | # dataL = dataL[y1:y1 + th, x1:x1 + tw] 84 | # right_img = np.asarray(right_img) 85 | # left_img = np.asarray(left_img) 86 | 87 | # geometric unsymmetric-augmentation 88 | angle = 0; 89 | px = 0 90 | if np.random.binomial(1, 0.5): 91 | # angle = 0.1; 92 | # px = 2 93 | angle = 0.05 94 | px = 1 95 | co_transform = flow_transforms.Compose([ 96 | # flow_transforms.RandomVdisp(angle, px), 97 | flow_transforms.Scale(0.5, order=self.order), 98 
| flow_transforms.RandomCrop((th, tw)), 99 | ]) 100 | augmented, disparity = co_transform([left_img, right_img], disparity) 101 | left_img = augmented[0] 102 | right_img = augmented[1] 103 | 104 | right_img.flags.writeable = True 105 | if np.random.binomial(1,0.2): 106 | sx = int(np.random.uniform(35,100)) 107 | sy = int(np.random.uniform(25,75)) 108 | cx = int(np.random.uniform(sx,right_img.shape[0]-sx)) 109 | cy = int(np.random.uniform(sy,right_img.shape[1]-sy)) 110 | right_img[cx-sx:cx+sx,cy-sy:cy+sy] = np.mean(np.mean(right_img,0),0)[np.newaxis,np.newaxis] 111 | 112 | # to tensor, normalize 113 | disparity = np.ascontiguousarray(disparity, dtype=np.float32) 114 | processed = get_transform() 115 | left_img = processed(left_img) 116 | right_img = processed(right_img) 117 | 118 | return {"left": left_img, 119 | "right": right_img, 120 | "disparity": disparity} 121 | else: 122 | # w, h = left_img.size 123 | right_img = np.asarray(right_img) 124 | left_img = np.asarray(left_img) 125 | # co_transform = flow_transforms.Compose([ 126 | # # flow_transforms.RandomVdisp(angle, px), 127 | # flow_transforms.Scale(0.5, order=self.order), 128 | # # flow_transforms.RandomCrop((th, tw)), 129 | # ]) 130 | # augmented, disparity = co_transform([left_img, right_img], disparity) 131 | # left_img = augmented[0] 132 | # right_img = augmented[1] 133 | # right_img = cv2.resize(right_img, None, fx=0.5,fy=0.5 ,interpolation=cv2.INTER_CUBIC) 134 | # left_img = cv2.resize(left_img, None, fx=0.5, fy=0.5, interpolation=cv2.INTER_CUBIC) 135 | disparity = np.ascontiguousarray(disparity, dtype=np.float32) 136 | # normalize 137 | h = left_img.shape[0] 138 | w = left_img.shape[1] 139 | processed = get_transform() 140 | left_img = processed(left_img).numpy() 141 | right_img = processed(right_img).numpy() 142 | # h, w, _ = left_img.shape 143 | # pad to size 1248x384 144 | top_pad = 32 - (h % 32) 145 | right_pad = 32 - (w % 32) 146 | assert top_pad > 0 and right_pad > 0 147 | # pad images 148 | left_img = np.lib.pad(left_img, ((0, 0), (top_pad, 0), (0, right_pad)), mode='constant', constant_values=0) 149 | right_img = np.lib.pad(right_img, ((0, 0), (top_pad, 0), (0, right_pad)), mode='constant', 150 | constant_values=0) 151 | # pad disparity gt 152 | if disparity is not None: 153 | assert len(disparity.shape) == 2 154 | disparity = np.lib.pad(disparity, ((top_pad, 0), (0, right_pad)), mode='constant', constant_values=0) 155 | 156 | if disparity is not None: 157 | return {"left": left_img, 158 | "right": right_img, 159 | "disparity": disparity, 160 | "top_pad": top_pad, 161 | "right_pad": right_pad, 162 | "left_filename": self.left[index], 163 | "right_filename": self.right[index] 164 | } 165 | else: 166 | return {"left": left_img, 167 | "right": right_img, 168 | "top_pad": top_pad, 169 | "right_pad": right_pad, 170 | "left_filename": self.left[index], 171 | "right_filename": self.right[index]} 172 | 173 | def __len__(self): 174 | return len(self.left) 175 | -------------------------------------------------------------------------------- /KITTI12/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .kitti_dataset import KITTIDataset 2 | from .sceneflow_dataset import SceneFlowDatset 3 | 4 | __datasets__ = { 5 | "sceneflow": SceneFlowDatset, 6 | "kitti": KITTIDataset 7 | } 8 | -------------------------------------------------------------------------------- /KITTI12/datasets/__init__.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI12/datasets/__init__.pyc -------------------------------------------------------------------------------- /KITTI12/datasets/__pycache__/MiddleburyLoader.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI12/datasets/__pycache__/MiddleburyLoader.cpython-37.pyc -------------------------------------------------------------------------------- /KITTI12/datasets/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI12/datasets/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /KITTI12/datasets/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI12/datasets/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /KITTI12/datasets/__pycache__/data_io.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI12/datasets/__pycache__/data_io.cpython-37.pyc -------------------------------------------------------------------------------- /KITTI12/datasets/__pycache__/data_io.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI12/datasets/__pycache__/data_io.cpython-38.pyc -------------------------------------------------------------------------------- /KITTI12/datasets/__pycache__/eth3dLoader.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI12/datasets/__pycache__/eth3dLoader.cpython-37.pyc -------------------------------------------------------------------------------- /KITTI12/datasets/__pycache__/flow_transforms.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI12/datasets/__pycache__/flow_transforms.cpython-37.pyc -------------------------------------------------------------------------------- /KITTI12/datasets/__pycache__/flow_transforms.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI12/datasets/__pycache__/flow_transforms.cpython-38.pyc -------------------------------------------------------------------------------- /KITTI12/datasets/__pycache__/kitti_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI12/datasets/__pycache__/kitti_dataset.cpython-37.pyc 
-------------------------------------------------------------------------------- /KITTI12/datasets/__pycache__/kitti_dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI12/datasets/__pycache__/kitti_dataset.cpython-38.pyc -------------------------------------------------------------------------------- /KITTI12/datasets/__pycache__/listfiles.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI12/datasets/__pycache__/listfiles.cpython-37.pyc -------------------------------------------------------------------------------- /KITTI12/datasets/__pycache__/readpfm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI12/datasets/__pycache__/readpfm.cpython-37.pyc -------------------------------------------------------------------------------- /KITTI12/datasets/__pycache__/sceneflow_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI12/datasets/__pycache__/sceneflow_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /KITTI12/datasets/__pycache__/sceneflow_dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI12/datasets/__pycache__/sceneflow_dataset.cpython-38.pyc -------------------------------------------------------------------------------- /KITTI12/datasets/data_io.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import re 3 | import torchvision.transforms as transforms 4 | 5 | 6 | def get_transform(): 7 | mean = [0.485, 0.456, 0.406] 8 | std = [0.229, 0.224, 0.225] 9 | 10 | return transforms.Compose([ 11 | transforms.ToTensor(), 12 | transforms.Normalize(mean=mean, std=std), 13 | ]) 14 | 15 | 16 | # read all lines in a file 17 | def read_all_lines(filename): 18 | with open(filename) as f: 19 | lines = [line.rstrip() for line in f.readlines()] 20 | return lines 21 | 22 | 23 | # read an .pfm file into numpy array, used to load SceneFlow disparity files 24 | def pfm_imread(filename): 25 | file = open(filename, 'rb') 26 | color = None 27 | width = None 28 | height = None 29 | scale = None 30 | endian = None 31 | 32 | header = file.readline().decode('utf-8').rstrip() 33 | if header == 'PF': 34 | color = True 35 | elif header == 'Pf': 36 | color = False 37 | else: 38 | raise Exception('Not a PFM file.') 39 | 40 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8')) 41 | if dim_match: 42 | width, height = map(int, dim_match.groups()) 43 | else: 44 | raise Exception('Malformed PFM header.') 45 | 46 | scale = float(file.readline().rstrip()) 47 | if scale < 0: # little-endian 48 | endian = '<' 49 | scale = -scale 50 | else: 51 | endian = '>' # big-endian 52 | 53 | data = np.fromfile(file, endian + 'f') 54 | shape = (height, width, 3) if color else (height, width) 55 | 56 | data = np.reshape(data, 
shape) 57 | data = np.flipud(data) 58 | return data, scale 59 | -------------------------------------------------------------------------------- /KITTI12/datasets/data_io.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI12/datasets/data_io.pyc -------------------------------------------------------------------------------- /KITTI12/datasets/eth3dLoader.py: -------------------------------------------------------------------------------- 1 | import os, torch, torch.utils.data as data 2 | from PIL import Image 3 | import numpy as np 4 | from . import flow_transforms 5 | import pdb 6 | import torchvision 7 | import warnings 8 | from . import readpfm as rp 9 | from datasets.data_io import get_transform, read_all_lines 10 | warnings.filterwarnings('ignore', '.*output shape of zoom.*') 11 | 12 | IMG_EXTENSIONS = [ 13 | '.jpg', '.JPG', '.jpeg', '.JPEG', 14 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP'] 15 | 16 | def is_image_file(filename): 17 | return any((filename.endswith(extension) for extension in IMG_EXTENSIONS)) 18 | 19 | 20 | def default_loader(path): 21 | return Image.open(path).convert('RGB') 22 | 23 | 24 | def disparity_loader(path): 25 | if '.png' in path: 26 | data = Image.open(path) 27 | data = np.ascontiguousarray(data,dtype=np.float32)/256 28 | return data 29 | else: 30 | data = rp.readPFM(path)[0] 31 | data = np.ascontiguousarray(data, dtype=np.float32) 32 | return data 33 | 34 | 35 | class myImageFloder(data.Dataset): 36 | 37 | def __init__(self, left, right, left_disparity, training, right_disparity=None, loader=default_loader, dploader=disparity_loader): 38 | self.left = left 39 | self.right = right 40 | self.disp_L = left_disparity 41 | self.disp_R = right_disparity 42 | self.training = training 43 | self.loader = loader 44 | self.dploader = dploader 45 | 46 | def __getitem__(self, index): 47 | left = self.left[index] 48 | right = self.right[index] 49 | left_img = self.loader(left) 50 | right_img = self.loader(right) 51 | if self.disp_L is not None: 52 | disp_L = self.disp_L[index] 53 | disparity = self.dploader(disp_L) 54 | disparity[disparity == np.inf] = 0 55 | else: 56 | disparity = None 57 | 58 | if self.training: 59 | th, tw = 256, 512 60 | #th, tw = 320, 704 61 | random_brightness = np.random.uniform(0.5, 2.0, 2) 62 | random_gamma = np.random.uniform(0.8, 1.2, 2) 63 | random_contrast = np.random.uniform(0.8, 1.2, 2) 64 | left_img = torchvision.transforms.functional.adjust_brightness(left_img, random_brightness[0]) 65 | left_img = torchvision.transforms.functional.adjust_gamma(left_img, random_gamma[0]) 66 | left_img = torchvision.transforms.functional.adjust_contrast(left_img, random_contrast[0]) 67 | right_img = torchvision.transforms.functional.adjust_brightness(right_img, random_brightness[1]) 68 | right_img = torchvision.transforms.functional.adjust_gamma(right_img, random_gamma[1]) 69 | right_img = torchvision.transforms.functional.adjust_contrast(right_img, random_contrast[1]) 70 | right_img = np.asarray(right_img) 71 | left_img = np.asarray(left_img) 72 | 73 | # w, h = left_img.size 74 | # th, tw = 256, 512 75 | # 76 | # x1 = random.randint(0, w - tw) 77 | # y1 = random.randint(0, h - th) 78 | # 79 | # left_img = left_img.crop((x1, y1, x1 + tw, y1 + th)) 80 | # right_img = right_img.crop((x1, y1, x1 + tw, y1 + th)) 81 | # dataL = dataL[y1:y1 + th, x1:x1 + tw] 82 | # right_img = np.asarray(right_img) 83 | # left_img = 
np.asarray(left_img) 84 | 85 | # geometric unsymmetric-augmentation 86 | angle = 0; 87 | px = 0 88 | if np.random.binomial(1, 0.5): 89 | # angle = 0.1; 90 | # px = 2 91 | angle = 0.05 92 | px = 1 93 | co_transform = flow_transforms.Compose([ 94 | # flow_transforms.RandomVdisp(angle, px), 95 | # flow_transforms.Scale(np.random.uniform(self.rand_scale[0], self.rand_scale[1]), order=self.order), 96 | flow_transforms.RandomCrop((th, tw)), 97 | ]) 98 | augmented, disparity = co_transform([left_img, right_img], disparity) 99 | left_img = augmented[0] 100 | right_img = augmented[1] 101 | 102 | right_img.flags.writeable = True 103 | if np.random.binomial(1,0.2): 104 | sx = int(np.random.uniform(35,100)) 105 | sy = int(np.random.uniform(25,75)) 106 | cx = int(np.random.uniform(sx,right_img.shape[0]-sx)) 107 | cy = int(np.random.uniform(sy,right_img.shape[1]-sy)) 108 | right_img[cx-sx:cx+sx,cy-sy:cy+sy] = np.mean(np.mean(right_img,0),0)[np.newaxis,np.newaxis] 109 | 110 | # to tensor, normalize 111 | disparity = np.ascontiguousarray(disparity, dtype=np.float32) 112 | processed = get_transform() 113 | left_img = processed(left_img) 114 | right_img = processed(right_img) 115 | 116 | return {"left": left_img, 117 | "right": right_img, 118 | "disparity": disparity} 119 | else: 120 | w, h = left_img.size 121 | 122 | # normalize 123 | processed = get_transform() 124 | left_img = processed(left_img).numpy() 125 | right_img = processed(right_img).numpy() 126 | 127 | # pad to size 1248x384 128 | top_pad = 32 - (h % 32) 129 | right_pad = 32 - (w % 32) 130 | assert top_pad > 0 and right_pad > 0 131 | # pad images 132 | left_img = np.lib.pad(left_img, ((0, 0), (top_pad, 0), (0, right_pad)), mode='constant', constant_values=0) 133 | right_img = np.lib.pad(right_img, ((0, 0), (top_pad, 0), (0, right_pad)), mode='constant', 134 | constant_values=0) 135 | # pad disparity gt 136 | if disparity is not None: 137 | assert len(disparity.shape) == 2 138 | disparity = np.lib.pad(disparity, ((top_pad, 0), (0, right_pad)), mode='constant', constant_values=0) 139 | 140 | if disparity is not None: 141 | return {"left": left_img, 142 | "right": right_img, 143 | "disparity": disparity, 144 | "top_pad": top_pad, 145 | "right_pad": right_pad} 146 | else: 147 | return {"left": left_img, 148 | "right": right_img, 149 | "top_pad": top_pad, 150 | "right_pad": right_pad, 151 | "left_filename": self.left[index], 152 | "right_filename": self.right[index]} 153 | 154 | def __len__(self): 155 | return len(self.left) 156 | -------------------------------------------------------------------------------- /KITTI12/datasets/flow_transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import random 4 | import numpy as np 5 | import numbers 6 | import pdb 7 | import cv2 8 | 9 | 10 | class Compose(object): 11 | """ Composes several co_transforms together. 12 | """ 13 | 14 | def __init__(self, co_transforms): 15 | self.co_transforms = co_transforms 16 | 17 | def __call__(self, input, target): 18 | for t in self.co_transforms: 19 | input,target = t(input,target) 20 | return input,target 21 | 22 | 23 | 24 | class Scale(object): 25 | """ Rescales the inputs and target arrays to the given 'size'. 
26 | """ 27 | 28 | def __init__(self, size, order=2): 29 | self.ratio = size 30 | self.order = order 31 | if order==0: 32 | self.code=cv2.INTER_NEAREST 33 | elif order==1: 34 | self.code=cv2.INTER_LINEAR 35 | elif order==2: 36 | self.code=cv2.INTER_CUBIC 37 | 38 | def __call__(self, inputs, target): 39 | h, w, _ = inputs[0].shape 40 | ratio = self.ratio 41 | 42 | inputs[0] = cv2.resize(inputs[0], None, fx=ratio,fy=ratio,interpolation=cv2.INTER_CUBIC) 43 | inputs[1] = cv2.resize(inputs[1], None, fx=ratio,fy=ratio,interpolation=cv2.INTER_CUBIC) 44 | target = cv2.resize(target, None, fx=ratio,fy=ratio,interpolation=self.code) * ratio 45 | 46 | return inputs, target 47 | 48 | 49 | class RandomCrop(object): 50 | """ Randomly crop images 51 | """ 52 | 53 | def __init__(self, size): 54 | if isinstance(size, numbers.Number): 55 | self.size = (int(size), int(size)) 56 | else: 57 | self.size = size 58 | 59 | def __call__(self, inputs,target): 60 | h, w, _ = inputs[0].shape 61 | th, tw = self.size 62 | if w < tw: tw=w 63 | if h < th: th=h 64 | 65 | x1 = random.randint(0, w - tw) 66 | y1 = random.randint(0, h - th) 67 | inputs[0] = inputs[0][y1: y1 + th,x1: x1 + tw] 68 | inputs[1] = inputs[1][y1: y1 + th,x1: x1 + tw] 69 | return inputs, target[y1: y1 + th,x1: x1 + tw] 70 | 71 | 72 | class RandomVdisp(object): 73 | """Random vertical disparity augmentation 74 | """ 75 | 76 | def __init__(self, angle, px, diff_angle=0, order=2, reshape=False): 77 | self.angle = angle 78 | self.reshape = reshape 79 | self.order = order 80 | self.diff_angle = diff_angle 81 | self.px = px 82 | 83 | def __call__(self, inputs,target): 84 | px2 = random.uniform(-self.px,self.px) 85 | angle2 = random.uniform(-self.angle,self.angle) 86 | 87 | image_center = (np.random.uniform(0,inputs[1].shape[0]),\ 88 | np.random.uniform(0,inputs[1].shape[1])) 89 | rot_mat = cv2.getRotationMatrix2D(image_center, angle2, 1.0) 90 | inputs[1] = cv2.warpAffine(inputs[1], rot_mat, inputs[1].shape[1::-1], flags=cv2.INTER_LINEAR) 91 | trans_mat = np.float32([[1,0,0],[0,1,px2]]) 92 | inputs[1] = cv2.warpAffine(inputs[1], trans_mat, inputs[1].shape[1::-1], flags=cv2.INTER_LINEAR) 93 | return inputs,target 94 | -------------------------------------------------------------------------------- /KITTI12/datasets/kitti_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from torch.utils.data import Dataset 4 | from PIL import Image 5 | import numpy as np 6 | from datasets.data_io import get_transform, read_all_lines 7 | from . 
import flow_transforms 8 | import torchvision 9 | 10 | 11 | class KITTIDataset(Dataset): 12 | def __init__(self, datapath, list_filename, training): 13 | self.datapath = datapath 14 | self.left_filenames, self.right_filenames, self.disp_filenames = self.load_path(list_filename) 15 | self.training = training 16 | if self.training: 17 | assert self.disp_filenames is not None 18 | 19 | def load_path(self, list_filename): 20 | lines = read_all_lines(list_filename) 21 | splits = [line.split() for line in lines] 22 | left_images = [x[0] for x in splits] 23 | right_images = [x[1] for x in splits] 24 | if len(splits[0]) == 2: # ground truth not available 25 | return left_images, right_images, None 26 | else: 27 | disp_images = [x[2] for x in splits] 28 | return left_images, right_images, disp_images 29 | 30 | def load_image(self, filename): 31 | return Image.open(filename).convert('RGB') 32 | 33 | def load_disp(self, filename): 34 | data = Image.open(filename) 35 | data = np.array(data, dtype=np.float32) / 256. 36 | return data 37 | 38 | def __len__(self): 39 | return len(self.left_filenames) 40 | 41 | def __getitem__(self, index): 42 | left_img = self.load_image(os.path.join(self.datapath, self.left_filenames[index])) 43 | right_img = self.load_image(os.path.join(self.datapath, self.right_filenames[index])) 44 | 45 | if self.disp_filenames: # has disparity ground truth 46 | disparity = self.load_disp(os.path.join(self.datapath, self.disp_filenames[index])) 47 | else: 48 | disparity = None 49 | 50 | if self.training: 51 | th, tw = 256, 512 52 | #th, tw = 320, 1216 53 | #th, tw = 320, 704 54 | random_brightness = np.random.uniform(0.5, 2.0, 2) 55 | random_gamma = np.random.uniform(0.8, 1.2, 2) 56 | random_contrast = np.random.uniform(0.8, 1.2, 2) 57 | left_img = torchvision.transforms.functional.adjust_brightness(left_img, random_brightness[0]) 58 | left_img = torchvision.transforms.functional.adjust_gamma(left_img, random_gamma[0]) 59 | left_img = torchvision.transforms.functional.adjust_contrast(left_img, random_contrast[0]) 60 | right_img = torchvision.transforms.functional.adjust_brightness(right_img, random_brightness[1]) 61 | right_img = torchvision.transforms.functional.adjust_gamma(right_img, random_gamma[1]) 62 | right_img = torchvision.transforms.functional.adjust_contrast(right_img, random_contrast[1]) 63 | right_img = np.asarray(right_img) 64 | left_img = np.asarray(left_img) 65 | 66 | # w, h = left_img.size 67 | # th, tw = 256, 512 68 | # 69 | # x1 = random.randint(0, w - tw) 70 | # y1 = random.randint(0, h - th) 71 | # 72 | # left_img = left_img.crop((x1, y1, x1 + tw, y1 + th)) 73 | # right_img = right_img.crop((x1, y1, x1 + tw, y1 + th)) 74 | # dataL = dataL[y1:y1 + th, x1:x1 + tw] 75 | # right_img = np.asarray(right_img) 76 | # left_img = np.asarray(left_img) 77 | 78 | # geometric unsymmetric-augmentation 79 | angle = 0; 80 | px = 0 81 | if np.random.binomial(1, 0.5): 82 | # angle = 0.1; 83 | # px = 2 84 | angle = 0.05 85 | px = 1 86 | co_transform = flow_transforms.Compose([ 87 | # flow_transforms.RandomVdisp(angle, px), 88 | # flow_transforms.Scale(np.random.uniform(self.rand_scale[0], self.rand_scale[1]), order=self.order), 89 | flow_transforms.RandomCrop((th, tw)), 90 | ]) 91 | augmented, disparity = co_transform([left_img, right_img], disparity) 92 | left_img = augmented[0] 93 | right_img = augmented[1] 94 | 95 | # right_img.flags.writeable = True 96 | if np.random.binomial(1,0.2): 97 | sx = int(np.random.uniform(35,100)) 98 | sy = int(np.random.uniform(25,75)) 99 | cx = 
int(np.random.uniform(sx,right_img.shape[0]-sx)) 100 | cy = int(np.random.uniform(sy,right_img.shape[1]-sy)) 101 | right_img[cx-sx:cx+sx,cy-sy:cy+sy] = np.mean(np.mean(right_img,0),0)[np.newaxis,np.newaxis] 102 | 103 | # to tensor, normalize 104 | disparity = np.ascontiguousarray(disparity, dtype=np.float32) 105 | processed = get_transform() 106 | left_img = processed(left_img) 107 | right_img = processed(right_img) 108 | 109 | return {"left": left_img, 110 | "right": right_img, 111 | "disparity": disparity} 112 | else: 113 | w, h = left_img.size 114 | 115 | # normalize 116 | processed = get_transform() 117 | left_img = processed(left_img).numpy() 118 | right_img = processed(right_img).numpy() 119 | 120 | # pad to size 1248x384 121 | top_pad = 384 - h 122 | right_pad = 1248 - w 123 | assert top_pad > 0 and right_pad > 0 124 | # pad images 125 | left_img = np.lib.pad(left_img, ((0, 0), (top_pad, 0), (0, right_pad)), mode='constant', constant_values=0) 126 | right_img = np.lib.pad(right_img, ((0, 0), (top_pad, 0), (0, right_pad)), mode='constant', 127 | constant_values=0) 128 | # pad disparity gt 129 | if disparity is not None: 130 | assert len(disparity.shape) == 2 131 | disparity = np.lib.pad(disparity, ((top_pad, 0), (0, right_pad)), mode='constant', constant_values=0) 132 | 133 | if disparity is not None: 134 | return {"left": left_img, 135 | "right": right_img, 136 | "disparity": disparity, 137 | "top_pad": top_pad, 138 | "right_pad": right_pad, 139 | "left_filename": self.left_filenames[index]} 140 | else: 141 | return {"left": left_img, 142 | "right": right_img, 143 | "top_pad": top_pad, 144 | "right_pad": right_pad, 145 | "left_filename": self.left_filenames[index], 146 | "right_filename": self.right_filenames[index]} 147 | -------------------------------------------------------------------------------- /KITTI12/datasets/kitti_dataset.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI12/datasets/kitti_dataset.pyc -------------------------------------------------------------------------------- /KITTI12/datasets/kitti_dataset_small.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from torch.utils.data import Dataset 4 | from PIL import Image 5 | import numpy as np 6 | from datasets.data_io import get_transform, read_all_lines 7 | from . import flow_transforms 8 | import torchvision 9 | 10 | 11 | class KITTIDataset(Dataset): 12 | def __init__(self, datapath, list_filename, training): 13 | self.datapath = datapath 14 | self.left_filenames, self.right_filenames, self.disp_filenames = self.load_path(list_filename) 15 | self.training = training 16 | if self.training: 17 | assert self.disp_filenames is not None 18 | 19 | def load_path(self, list_filename): 20 | lines = read_all_lines(list_filename) 21 | splits = [line.split() for line in lines] 22 | left_images = [x[0] for x in splits] 23 | right_images = [x[1] for x in splits] 24 | if len(splits[0]) == 2: # ground truth not available 25 | return left_images, right_images, None 26 | else: 27 | disp_images = [x[2] for x in splits] 28 | return left_images, right_images, disp_images 29 | 30 | def load_image(self, filename): 31 | return Image.open(filename).convert('RGB') 32 | 33 | def load_disp(self, filename): 34 | data = Image.open(filename) 35 | data = np.array(data, dtype=np.float32) / 256. 
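# (Descriptive note: KITTI ground-truth disparity is stored as a uint16 PNG scaled by 256, so the division above recovers disparity in pixels; a value of 0 marks pixels without ground truth.)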
36 | return data 37 | 38 | def __len__(self): 39 | return len(self.left_filenames) 40 | 41 | def __getitem__(self, index): 42 | left_img = self.load_image(os.path.join(self.datapath, self.left_filenames[index])) 43 | right_img = self.load_image(os.path.join(self.datapath, self.right_filenames[index])) 44 | 45 | if self.disp_filenames: # has disparity ground truth 46 | disparity = self.load_disp(os.path.join(self.datapath, self.disp_filenames[index])) 47 | else: 48 | disparity = None 49 | 50 | if self.training: 51 | th, tw = 256, 512 52 | #th, tw = 320, 704 53 | random_brightness = np.random.uniform(0.5, 2.0, 2) 54 | random_gamma = np.random.uniform(0.8, 1.2, 2) 55 | random_contrast = np.random.uniform(0.8, 1.2, 2) 56 | left_img = torchvision.transforms.functional.adjust_brightness(left_img, random_brightness[0]) 57 | left_img = torchvision.transforms.functional.adjust_gamma(left_img, random_gamma[0]) 58 | left_img = torchvision.transforms.functional.adjust_contrast(left_img, random_contrast[0]) 59 | right_img = torchvision.transforms.functional.adjust_brightness(right_img, random_brightness[1]) 60 | right_img = torchvision.transforms.functional.adjust_gamma(right_img, random_gamma[1]) 61 | right_img = torchvision.transforms.functional.adjust_contrast(right_img, random_contrast[1]) 62 | right_img = np.asarray(right_img) 63 | left_img = np.asarray(left_img) 64 | 65 | # w, h = left_img.size 66 | # th, tw = 256, 512 67 | # 68 | # x1 = random.randint(0, w - tw) 69 | # y1 = random.randint(0, h - th) 70 | # 71 | # left_img = left_img.crop((x1, y1, x1 + tw, y1 + th)) 72 | # right_img = right_img.crop((x1, y1, x1 + tw, y1 + th)) 73 | # dataL = dataL[y1:y1 + th, x1:x1 + tw] 74 | # right_img = np.asarray(right_img) 75 | # left_img = np.asarray(left_img) 76 | 77 | # geometric unsymmetric-augmentation 78 | angle = 0; 79 | px = 0 80 | if np.random.binomial(1, 0.5): 81 | # angle = 0.1; 82 | # px = 2 83 | angle = 0.05 84 | px = 1 85 | co_transform = flow_transforms.Compose([ 86 | # flow_transforms.RandomVdisp(angle, px), 87 | # flow_transforms.Scale(np.random.uniform(self.rand_scale[0], self.rand_scale[1]), order=self.order), 88 | flow_transforms.RandomCrop((th, tw)), 89 | ]) 90 | augmented, disparity = co_transform([left_img, right_img], disparity) 91 | left_img = augmented[0] 92 | right_img = augmented[1] 93 | 94 | right_img.flags.writeable = True 95 | if np.random.binomial(1,0.2): 96 | sx = int(np.random.uniform(35,100)) 97 | sy = int(np.random.uniform(25,75)) 98 | cx = int(np.random.uniform(sx,right_img.shape[0]-sx)) 99 | cy = int(np.random.uniform(sy,right_img.shape[1]-sy)) 100 | right_img[cx-sx:cx+sx,cy-sy:cy+sy] = np.mean(np.mean(right_img,0),0)[np.newaxis,np.newaxis] 101 | 102 | # to tensor, normalize 103 | disparity = np.ascontiguousarray(disparity, dtype=np.float32) 104 | processed = get_transform() 105 | left_img = processed(left_img) 106 | right_img = processed(right_img) 107 | 108 | return {"left": left_img, 109 | "right": right_img, 110 | "disparity": disparity} 111 | else: 112 | w, h = left_img.size 113 | 114 | # normalize 115 | processed = get_transform() 116 | left_img = processed(left_img).numpy() 117 | right_img = processed(right_img).numpy() 118 | 119 | # pad to size 1248x384 120 | top_pad = 384 - h 121 | right_pad = 1248 - w 122 | assert top_pad > 0 and right_pad > 0 123 | # pad images 124 | left_img = np.lib.pad(left_img, ((0, 0), (top_pad, 0), (0, right_pad)), mode='constant', constant_values=0) 125 | right_img = np.lib.pad(right_img, ((0, 0), (top_pad, 0), (0, right_pad)), 
mode='constant', 126 | constant_values=0) 127 | # pad disparity gt 128 | if disparity is not None: 129 | assert len(disparity.shape) == 2 130 | disparity = np.lib.pad(disparity, ((top_pad, 0), (0, right_pad)), mode='constant', constant_values=0) 131 | 132 | if disparity is not None: 133 | return {"left": left_img, 134 | "right": right_img, 135 | "disparity": disparity, 136 | "top_pad": top_pad, 137 | "right_pad": right_pad} 138 | else: 139 | return {"left": left_img, 140 | "right": right_img, 141 | "top_pad": top_pad, 142 | "right_pad": right_pad, 143 | "left_filename": self.left_filenames[index], 144 | "right_filename": self.right_filenames[index]} 145 | -------------------------------------------------------------------------------- /KITTI12/datasets/listfiles.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | import pdb 4 | from PIL import Image 5 | import os 6 | import os.path 7 | import numpy as np 8 | import glob 9 | 10 | 11 | def dataloader(filepath): 12 | img_list = [i.split('/')[-1] for i in glob.glob('%s/*'%filepath) if os.path.isdir(i)] 13 | 14 | left_train = ['%s/%s/im0.png'% (filepath,img) for img in img_list] 15 | right_train = ['%s/%s/im1.png'% (filepath,img) for img in img_list] 16 | disp_train_L = ['%s/%s/disp0GT.pfm' % (filepath,img) for img in img_list] 17 | disp_train_R = ['%s/%s/disp1GT.pfm' % (filepath,img) for img in img_list] 18 | 19 | return left_train, right_train, disp_train_L, disp_train_R 20 | -------------------------------------------------------------------------------- /KITTI12/datasets/readpfm.py: -------------------------------------------------------------------------------- 1 | import re 2 | import numpy as np 3 | import sys 4 | 5 | 6 | def readPFM(file): 7 | file = open(file, 'rb') 8 | 9 | color = None 10 | width = None 11 | height = None 12 | scale = None 13 | endian = None 14 | 15 | header = file.readline().rstrip() 16 | if (sys.version[0]) == '3': 17 | header = header.decode('utf-8') 18 | if header == 'PF': 19 | color = True 20 | elif header == 'Pf': 21 | color = False 22 | else: 23 | raise Exception('Not a PFM file.') 24 | 25 | if (sys.version[0]) == '3': 26 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8')) 27 | else: 28 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline()) 29 | if dim_match: 30 | width, height = map(int, dim_match.groups()) 31 | else: 32 | raise Exception('Malformed PFM header.') 33 | 34 | if (sys.version[0]) == '3': 35 | scale = float(file.readline().rstrip().decode('utf-8')) 36 | else: 37 | scale = float(file.readline().rstrip()) 38 | 39 | if scale < 0: # little-endian 40 | endian = '<' 41 | scale = -scale 42 | else: 43 | endian = '>' # big-endian 44 | 45 | data = np.fromfile(file, endian + 'f') 46 | shape = (height, width, 3) if color else (height, width) 47 | 48 | data = np.reshape(data, shape) 49 | data = np.flipud(data) 50 | return data, scale 51 | 52 | -------------------------------------------------------------------------------- /KITTI12/datasets/sceneflow_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from torch.utils.data import Dataset 4 | from PIL import Image 5 | import numpy as np 6 | from datasets.data_io import get_transform, read_all_lines, pfm_imread 7 | from . 
import flow_transforms 8 | import torchvision 9 | 10 | 11 | class SceneFlowDatset(Dataset): 12 | def __init__(self, datapath, list_filename, training): 13 | self.datapath = datapath 14 | self.left_filenames, self.right_filenames, self.disp_filenames = self.load_path(list_filename) 15 | self.training = training 16 | 17 | def load_path(self, list_filename): 18 | lines = read_all_lines(list_filename) 19 | splits = [line.split() for line in lines] 20 | left_images = [x[0] for x in splits] 21 | right_images = [x[1] for x in splits] 22 | if len(splits[0]) == 2: # ground truth not available 23 | return left_images, right_images, None 24 | else: 25 | disp_images = [x[2] for x in splits] 26 | return left_images, right_images, disp_images 27 | # disp_images = [x[2] for x in splits] 28 | # return left_images, right_images, disp_images 29 | 30 | def load_image(self, filename): 31 | return Image.open(filename).convert('RGB') 32 | 33 | def load_disp(self, filename): 34 | data, scale = pfm_imread(filename) 35 | data = np.ascontiguousarray(data, dtype=np.float32) 36 | return data 37 | 38 | def __len__(self): 39 | return len(self.left_filenames) 40 | 41 | def __getitem__(self, index): 42 | left_img = self.load_image(os.path.join(self.datapath, self.left_filenames[index])) 43 | right_img = self.load_image(os.path.join(self.datapath, self.right_filenames[index])) 44 | # disparity = self.load_disp(os.path.join(self.datapath, self.disp_filenames[index])) 45 | if self.disp_filenames: # has disparity ground truth 46 | disparity = self.load_disp(os.path.join(self.datapath, self.disp_filenames[index])) 47 | else: 48 | disparity = None 49 | if self.training: 50 | 51 | th, tw = 256, 512 52 | random_brightness = np.random.uniform(0.5, 2.0, 2) 53 | random_gamma = np.random.uniform(0.8, 1.2, 2) 54 | random_contrast = np.random.uniform(0.8, 1.2, 2) 55 | left_img = torchvision.transforms.functional.adjust_brightness(left_img, random_brightness[0]) 56 | left_img = torchvision.transforms.functional.adjust_gamma(left_img, random_gamma[0]) 57 | left_img = torchvision.transforms.functional.adjust_contrast(left_img, random_contrast[0]) 58 | right_img = torchvision.transforms.functional.adjust_brightness(right_img, random_brightness[1]) 59 | right_img = torchvision.transforms.functional.adjust_gamma(right_img, random_gamma[1]) 60 | right_img = torchvision.transforms.functional.adjust_contrast(right_img, random_contrast[1]) 61 | right_img = np.asarray(right_img) 62 | left_img = np.asarray(left_img) 63 | 64 | # w, h = left_img.size 65 | # th, tw = 256, 512 66 | # 67 | # x1 = random.randint(0, w - tw) 68 | # y1 = random.randint(0, h - th) 69 | # 70 | # left_img = left_img.crop((x1, y1, x1 + tw, y1 + th)) 71 | # right_img = right_img.crop((x1, y1, x1 + tw, y1 + th)) 72 | # dataL = dataL[y1:y1 + th, x1:x1 + tw] 73 | # right_img = np.asarray(right_img) 74 | # left_img = np.asarray(left_img) 75 | 76 | # geometric unsymmetric-augmentation 77 | angle = 0; 78 | px = 0 79 | if np.random.binomial(1, 0.5): 80 | # angle = 0.1; 81 | # px = 2 82 | angle = 0.05 83 | px = 1 84 | co_transform = flow_transforms.Compose([ 85 | # flow_transforms.RandomVdisp(angle, px), 86 | # flow_transforms.Scale(np.random.uniform(self.rand_scale[0], self.rand_scale[1]), order=self.order), 87 | flow_transforms.RandomCrop((th, tw)), 88 | ]) 89 | augmented, disparity = co_transform([left_img, right_img], disparity) 90 | left_img = augmented[0] 91 | right_img = augmented[1] 92 | 93 | # randomly occlude a region 94 | right_img.flags.writeable = True 95 | if 
np.random.binomial(1,0.5):
 96 |                 sx = int(np.random.uniform(35,100))
 97 |                 sy = int(np.random.uniform(25,75))
 98 |                 cx = int(np.random.uniform(sx,right_img.shape[0]-sx))
 99 |                 cy = int(np.random.uniform(sy,right_img.shape[1]-sy))
100 |                 right_img[cx-sx:cx+sx,cy-sy:cy+sy] = np.mean(np.mean(right_img,0),0)[np.newaxis,np.newaxis]
101 | 
102 |             # w, h = left_img.size
103 | 
104 |             disparity = np.ascontiguousarray(disparity, dtype=np.float32)
105 |             processed = get_transform()
106 |             left_img = processed(left_img)
107 |             right_img = processed(right_img)
108 | 
109 | 
110 | 
111 |             return {"left": left_img,
112 |                     "right": right_img,
113 |                     "disparity": disparity}
114 |         else:
115 |             if disparity is not None:
116 |                 w, h = left_img.size
117 |                 crop_w, crop_h = 960, 512
118 | 
119 |                 left_img = left_img.crop((w - crop_w, h - crop_h, w, h))
120 |                 right_img = right_img.crop((w - crop_w, h - crop_h, w, h))
121 |                 disparity = disparity[h - crop_h:h, w - crop_w: w]
122 | 
123 |                 processed = get_transform()
124 |                 left_img = processed(left_img)
125 |                 right_img = processed(right_img)
126 | 
127 |                 return {"left": left_img,
128 |                         "right": right_img,
129 |                         "disparity": disparity,
130 |                         "top_pad": 0,
131 |                         "right_pad": 0}
132 |             else:
133 |                 w, h = left_img.size
134 |                 # normalize
135 |                 processed = get_transform()
136 |                 left_img = processed(left_img).numpy()
137 |                 right_img = processed(right_img).numpy()
138 | 
139 |                 # pad to size 1248x384
140 |                 top_pad = 32 - (h % 32)
141 |                 right_pad = 32 - (w % 32)
142 |                 assert top_pad > 0 and right_pad > 0
143 |                 # pad images
144 |                 left_img = np.lib.pad(left_img, ((0, 0), (top_pad, 0), (0, right_pad)), mode='constant',
145 |                                       constant_values=0)
146 |                 right_img = np.lib.pad(right_img, ((0, 0), (top_pad, 0), (0, right_pad)), mode='constant',
147 |                                        constant_values=0)
148 |                 return {"left": left_img,
149 |                         "right": right_img,
150 |                         "top_pad": top_pad,
151 |                         "right_pad": right_pad,
152 |                         "left_filename": self.left_filenames[index],
153 |                         "right_filename": self.right_filenames[index]}
--------------------------------------------------------------------------------
/KITTI12/datasets/sceneflow_dataset.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI12/datasets/sceneflow_dataset.pyc
--------------------------------------------------------------------------------
/KITTI12/filenames/kitti12_val.txt:
--------------------------------------------------------------------------------
 1 | training/colored_0/000180_10.png training/colored_1/000180_10.png training/disp_occ/000180_10.png
 2 | training/colored_0/000181_10.png training/colored_1/000181_10.png training/disp_occ/000181_10.png
 3 | training/colored_0/000182_10.png training/colored_1/000182_10.png training/disp_occ/000182_10.png
 4 | training/colored_0/000183_10.png training/colored_1/000183_10.png training/disp_occ/000183_10.png
 5 | training/colored_0/000184_10.png training/colored_1/000184_10.png training/disp_occ/000184_10.png
 6 | training/colored_0/000185_10.png training/colored_1/000185_10.png training/disp_occ/000185_10.png
 7 | training/colored_0/000186_10.png training/colored_1/000186_10.png training/disp_occ/000186_10.png
 8 | training/colored_0/000187_10.png training/colored_1/000187_10.png training/disp_occ/000187_10.png
 9 | training/colored_0/000188_10.png training/colored_1/000188_10.png training/disp_occ/000188_10.png
10 | training/colored_0/000189_10.png training/colored_1/000189_10.png training/disp_occ/000189_10.png
11 | training/colored_0/000190_10.png training/colored_1/000190_10.png training/disp_occ/000190_10.png
12 | training/colored_0/000191_10.png training/colored_1/000191_10.png training/disp_occ/000191_10.png
13 | training/colored_0/000192_10.png training/colored_1/000192_10.png training/disp_occ/000192_10.png
14 | training/colored_0/000193_10.png training/colored_1/000193_10.png training/disp_occ/000193_10.png
15 | 
--------------------------------------------------------------------------------
/KITTI12/models/__init__.py:
--------------------------------------------------------------------------------
 1 | from models.pwcnet import PWCNet_G, PWCNet_GC
 2 | from models.pwcnet_ddim import PWCNet_ddimgc
 3 | from models.loss import model_loss
 4 | 
 5 | __models__ = {
 6 |     "gwcnet-g": PWCNet_G,
 7 |     "gwcnet-gc": PWCNet_GC,
 8 |     "pwc_ddimgc": PWCNet_ddimgc
 9 | }
10 | 
--------------------------------------------------------------------------------
/KITTI12/models/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI12/models/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/KITTI12/models/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI12/models/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/KITTI12/models/__pycache__/gwcnet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI12/models/__pycache__/gwcnet.cpython-37.pyc
--------------------------------------------------------------------------------
/KITTI12/models/__pycache__/head.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI12/models/__pycache__/head.cpython-38.pyc
--------------------------------------------------------------------------------
/KITTI12/models/__pycache__/loss.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI12/models/__pycache__/loss.cpython-37.pyc
--------------------------------------------------------------------------------
/KITTI12/models/__pycache__/loss.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI12/models/__pycache__/loss.cpython-38.pyc
--------------------------------------------------------------------------------
/KITTI12/models/__pycache__/pwcnet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI12/models/__pycache__/pwcnet.cpython-37.pyc
--------------------------------------------------------------------------------
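A note on `KITTI12/models/__init__.py` above: `__models__` is a name-to-constructor registry, so the training and test scripts can select an architecture from a command-line string. A minimal usage sketch follows; the `--model` and `--maxdisp` argument names and the constructor signature are illustrative assumptions, since `main.py` is not shown in this section.

import argparse
from models import __models__  # name -> constructor registry defined above

parser = argparse.ArgumentParser()
parser.add_argument('--model', default='pwc_ddimgc', choices=__models__.keys())
parser.add_argument('--maxdisp', type=int, default=192)  # assumed constructor argument
args = parser.parse_args()

model = __models__[args.model](args.maxdisp)  # look up the class and instantiate it
--------------------------------------------------------------------------------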
/KITTI12/models/__pycache__/pwcnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI12/models/__pycache__/pwcnet.cpython-38.pyc -------------------------------------------------------------------------------- /KITTI12/models/__pycache__/pwcnet_ddim.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI12/models/__pycache__/pwcnet_ddim.cpython-38.pyc -------------------------------------------------------------------------------- /KITTI12/models/__pycache__/submodule.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI12/models/__pycache__/submodule.cpython-37.pyc -------------------------------------------------------------------------------- /KITTI12/models/__pycache__/submodule.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI12/models/__pycache__/submodule.cpython-38.pyc -------------------------------------------------------------------------------- /KITTI12/models/head.py: -------------------------------------------------------------------------------- 1 | """ 2 | DiffusionDet Transformer class. 3 | 4 | Copy-paste from torch.nn.Transformer with modifications: 5 | * positional encodings are passed in MHattention 6 | * extra LN at the end of encoder is removed 7 | * decoder returns a stack of activations from all decoding layers 8 | """ 9 | import copy 10 | import math 11 | 12 | import numpy as np 13 | import torch 14 | from torch import nn, Tensor 15 | import torch.nn.functional as F 16 | 17 | 18 | 19 | _DEFAULT_SCALE_CLAMP = math.log(100000.0 / 16) 20 | 21 | 22 | class SinusoidalPositionEmbeddings(nn.Module): 23 | def __init__(self, dim): 24 | super().__init__() 25 | self.dim = dim 26 | 27 | def forward(self, time): 28 | device = time.device 29 | half_dim = self.dim // 2 30 | embeddings = math.log(10000) / (half_dim - 1) 31 | embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings) 32 | embeddings = time[:, None] * embeddings[None, :] 33 | embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1) 34 | return embeddings 35 | 36 | 37 | class GaussianFourierProjection(nn.Module): 38 | """Gaussian random features for encoding time steps.""" 39 | 40 | def __init__(self, embed_dim, scale=30.): 41 | super().__init__() 42 | # Randomly sample weights during initialization. These weights are fixed 43 | # during optimization and are not trainable. 
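# (Descriptive note: random Fourier features. W is drawn once from N(0, scale^2) and frozen; the forward pass below maps t to [sin(2*pi*t*W), cos(2*pi*t*W)], cf. Tancik et al., 2020.)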
44 | self.W = nn.Parameter(torch.randn(embed_dim // 2) * scale, requires_grad=False) 45 | 46 | def forward(self, x): 47 | x_proj = x[:, None] * self.W[None, :] * 2 * np.pi 48 | return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) 49 | 50 | 51 | class DynamicHead(nn.Module): 52 | 53 | def __init__(self, d_model): 54 | super().__init__() 55 | self.d_model = d_model 56 | time_dim = d_model * 4 57 | self.time_mlp = nn.Sequential( 58 | SinusoidalPositionEmbeddings(d_model), 59 | nn.Linear(d_model, time_dim), 60 | nn.GELU(), 61 | nn.Linear(time_dim, time_dim), 62 | ) 63 | self.block_time_mlp = nn.Sequential(nn.SiLU(), nn.Linear(d_model * 4, d_model)) 64 | #self.block_time_mlp = nn.Sequential(nn.SiLU(), nn.Linear(d_model * 4, d_model), nn.Sigmoid()) 65 | 66 | self._reset_parameters() 67 | 68 | def _reset_parameters(self): 69 | # init all parameters. 70 | for p in self.parameters(): 71 | if p.dim() > 1: 72 | nn.init.xavier_uniform_(p) 73 | 74 | def forward(self, noisy, t): 75 | time_emb = self.time_mlp(t) 76 | scale_shift = self.block_time_mlp(time_emb).unsqueeze(-1).unsqueeze(-1) 77 | noisy = noisy + scale_shift 78 | #noisy = noisy * scale_shift 79 | # scale, shift = scale_shift.chunk(2, dim=1) 80 | # volume = volume * (scale + 1) + shift 81 | 82 | return noisy -------------------------------------------------------------------------------- /KITTI12/models/loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | 3 | 4 | def model_loss(disp_ests, disp_gt, mask): 5 | weights = [0.5, 0.5, 0.5, 0.7, 1.0, 1.3] 6 | all_losses = [] 7 | for disp_est, weight in zip(disp_ests, weights): 8 | all_losses.append(weight * F.smooth_l1_loss(disp_est[mask], disp_gt[mask], size_average=True)) 9 | return sum(all_losses) 10 | 11 | 12 | -------------------------------------------------------------------------------- /KITTI12/models/relu/submodule.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import torch.nn as nn 4 | import torch.utils.data 5 | from torch.autograd import Variable 6 | from torch.autograd.function import Function 7 | import torch.nn.functional as F 8 | import numpy as np 9 | 10 | 11 | class Mish(nn.Module): 12 | def __init__(self): 13 | super().__init__() 14 | print("Mish activation loaded...") 15 | 16 | def forward(self, x): 17 | #save 1 second per epoch with no x= x*() and then return x...just inline it. 
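The `model_loss` above averages smooth-L1 over six intermediate predictions with increasing weights; note that `size_average=True` is a long-deprecated flag, with `reduction='mean'` as the modern equivalent. A self-contained sketch with made-up tensor sizes:

```python
import torch
import torch.nn.functional as F

def model_loss(disp_ests, disp_gt, mask, weights=(0.5, 0.5, 0.5, 0.7, 1.0, 1.3)):
    # reduction='mean' is the current spelling of size_average=True
    return sum(w * F.smooth_l1_loss(d[mask], disp_gt[mask], reduction='mean')
               for d, w in zip(disp_ests, weights))

disp_gt = torch.rand(2, 64, 128) * 192                  # fake ground truth
mask = (disp_gt > 0) & (disp_gt < 192)                  # valid pixels, as in test.py
disp_ests = [disp_gt + 0.1 * torch.randn_like(disp_gt) for _ in range(6)]
print(model_loss(disp_ests, disp_gt, mask).item())
```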
18 | return x *( torch.tanh(F.softplus(x))) 19 | 20 | 21 | def convbn(in_channels, out_channels, kernel_size, stride, pad, dilation): 22 | return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, 23 | padding=dilation if dilation > 1 else pad, dilation=dilation, bias=False), 24 | nn.BatchNorm2d(out_channels)) 25 | 26 | 27 | def convbn_3d(in_channels, out_channels, kernel_size, stride, pad): 28 | return nn.Sequential(nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, 29 | padding=pad, bias=False), 30 | nn.BatchNorm3d(out_channels)) 31 | 32 | 33 | def disparity_regression(x, maxdisp): 34 | assert len(x.shape) == 4 35 | disp_values = torch.arange(0, maxdisp, dtype=x.dtype, device=x.device) 36 | disp_values = disp_values.view(1, maxdisp, 1, 1) 37 | return torch.sum(x * disp_values, 1, keepdim=False) 38 | 39 | 40 | def disp_regression_nearby(similarity, disp_step, half_support_window=2): 41 | """Returns predicted disparity with subpixel_map(disp_similarity). 42 | 43 | Predicted disparity is computed as: 44 | 45 | d_predicted = sum_d( d * P_predicted(d)), 46 | where | d - d_similarity_maximum | < half_size 47 | 48 | Args: 49 | similarity: Tensor with similarities with indices 50 | [example_index, disparity_index, y, x]. 51 | disp_step: disparity difference between near-by 52 | disparity indices in "similarities" tensor. 53 | half_support_window: defines size of disparity window in pixels 54 | around disparity with maximum similarity, 55 | which is used to convert similarities 56 | to probabilities and compute mean. 57 | """ 58 | 59 | assert 4 == similarity.dim(), \ 60 | 'Similarity should 4D Tensor,but get {}D Tensor'.format(similarity.dim()) 61 | 62 | # In every location (x, y) find disparity with maximum similarity score. 63 | similar_maximum, idx_maximum = torch.max(similarity, dim=1, keepdim=True) 64 | idx_limit = similarity.size(1) - 1 65 | 66 | # Collect similarity scores for the disparities around the disparity 67 | # with the maximum similarity score. 68 | support_idx_disp = [] 69 | for idx_shift in range(-half_support_window, half_support_window + 1): 70 | idx_disp = idx_maximum + idx_shift 71 | idx_disp[idx_disp < 0] = 0 72 | idx_disp[idx_disp >= idx_limit] = idx_limit 73 | support_idx_disp.append(idx_disp) 74 | 75 | support_idx_disp = torch.cat(support_idx_disp, dim=1) 76 | support_similar = torch.gather(similarity, 1, support_idx_disp.long()) 77 | support_disp = support_idx_disp.float() * disp_step 78 | 79 | # Convert collected similarity scores to the disparity distribution 80 | # using softmax and compute disparity as a mean of this distribution. 
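`disparity_regression` above (and the windowed `disp_regression_nearby` being defined here) is a soft-argmin: a softmax probability volume is collapsed to its expected disparity. A toy check with hypothetical sizes:

```python
import torch
import torch.nn.functional as F

def disparity_regression(x, maxdisp):
    # x: B x D x H x W probabilities that sum to 1 over the D axis
    disp_values = torch.arange(0, maxdisp, dtype=x.dtype, device=x.device)
    return torch.sum(x * disp_values.view(1, maxdisp, 1, 1), 1)

prob = f = F.softmax(torch.randn(1, 192, 4, 4), dim=1)   # fake cost volume
disp = disparity_regression(prob, 192)                   # -> 1 x 4 x 4
print(disp.shape, float(disp.min()), float(disp.max()))
```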
81 | prob = F.softmax(support_similar, dim=1) 82 | disp = torch.sum(prob * support_disp.float(), dim=1) 83 | 84 | return disp 85 | 86 | def build_concat_volume(refimg_fea, targetimg_fea, maxdisp): 87 | B, C, H, W = refimg_fea.shape 88 | volume = refimg_fea.new_zeros([B, 2 * C, maxdisp, H, W]) 89 | for i in range(maxdisp): 90 | if i > 0: 91 | volume[:, :C, i, :, i:] = refimg_fea[:, :, :, i:] 92 | volume[:, C:, i, :, i:] = targetimg_fea[:, :, :, :-i] 93 | else: 94 | volume[:, :C, i, :, :] = refimg_fea 95 | volume[:, C:, i, :, :] = targetimg_fea 96 | volume = volume.contiguous() 97 | return volume 98 | 99 | 100 | def groupwise_correlation(fea1, fea2, num_groups): 101 | B, C, H, W = fea1.shape 102 | assert C % num_groups == 0 103 | channels_per_group = C // num_groups 104 | cost = (fea1 * fea2).view([B, num_groups, channels_per_group, H, W]).mean(dim=2) 105 | assert cost.shape == (B, num_groups, H, W) 106 | return cost 107 | 108 | 109 | def build_gwc_volume(refimg_fea, targetimg_fea, maxdisp, num_groups): 110 | B, C, H, W = refimg_fea.shape 111 | volume = refimg_fea.new_zeros([B, num_groups, maxdisp, H, W]) 112 | for i in range(maxdisp): 113 | if i > 0: 114 | volume[:, :, i, :, i:] = groupwise_correlation(refimg_fea[:, :, :, i:], targetimg_fea[:, :, :, :-i], 115 | num_groups) 116 | else: 117 | volume[:, :, i, :, :] = groupwise_correlation(refimg_fea, targetimg_fea, num_groups) 118 | volume = volume.contiguous() 119 | return volume 120 | 121 | def build_corrleation_volume(refimg_fea, targetimg_fea, maxdisp, num_groups): 122 | B, C, H, W = refimg_fea.shape 123 | volume = refimg_fea.new_zeros([B, num_groups, 2 * maxdisp + 1, H, W]) 124 | for i in range(-maxdisp, maxdisp+1): 125 | if i > 0: 126 | volume[:, :, i + maxdisp, :, i:] = groupwise_correlation(refimg_fea[:, :, :, i:], targetimg_fea[:, :, :, :-i], 127 | num_groups) 128 | elif i < 0: 129 | volume[:, :, i + maxdisp, :, :-i] = groupwise_correlation(refimg_fea[:, :, :, :-i], 130 | targetimg_fea[:, :, :, i:], 131 | num_groups) 132 | else: 133 | volume[:, :, i + maxdisp, :, :] = groupwise_correlation(refimg_fea, targetimg_fea, num_groups) 134 | volume = volume.contiguous() 135 | return volume 136 | 137 | def warp(x, disp): 138 | """ 139 | warp an image/tensor (imright) back to imleft, according to the disp 140 | 141 | x: [B, C, H, W] (imright) 142 | disp: [B, 1, H, W] disp 143 | 144 | """ 145 | B, C, H, W = x.size() 146 | device = x.get_device() 147 | # mesh grid 148 | xx = torch.arange(0, W, device=device).view(1, -1).repeat(H, 1) 149 | yy = torch.arange(0, H, device=device).view(-1, 1).repeat(1, W) 150 | xx = xx.view(1, 1, H, W).repeat(B, 1, 1, 1) 151 | yy = yy.view(1, 1, H, W).repeat(B, 1, 1, 1) 152 | xx = xx.float() 153 | yy = yy.float() 154 | # grid = torch.cat((xx, yy), 1).float() 155 | 156 | # if x.is_cuda: 157 | # xx = xx.float().cuda() 158 | # yy = yy.float().cuda() 159 | xx_warp = Variable(xx) - disp 160 | yy = Variable(yy) 161 | # xx_warp = xx - disp 162 | vgrid = torch.cat((xx_warp, yy), 1) 163 | # vgrid = Variable(grid) + flo 164 | # scale grid to [-1,1] 165 | vgrid[:, 0, :, :] = 2.0 * vgrid[:, 0, :, :].clone() / max(W - 1, 1) - 1.0 166 | vgrid[:, 1, :, :] = 2.0 * vgrid[:, 1, :, :].clone() / max(H - 1, 1) - 1.0 167 | 168 | vgrid = vgrid.permute(0, 2, 3, 1) 169 | output = nn.functional.grid_sample(x, vgrid) 170 | mask = torch.ones(x.size(), device=device, requires_grad=True) 171 | mask = nn.functional.grid_sample(mask, vgrid) 172 | 173 | mask[mask < 0.999] = 0 174 | mask[mask > 0] = 1 175 | 176 | return output * mask 177 | 178 | def 
FMish(x): 179 | 180 | ''' 181 | 182 | Applies the mish function element-wise: 183 | 184 | mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x))) 185 | 186 | See additional documentation for mish class. 187 | 188 | ''' 189 | 190 | return x * torch.tanh(F.softplus(x)) 191 | 192 | class BasicBlock(nn.Module): 193 | expansion = 1 194 | 195 | def __init__(self, inplanes, planes, stride, downsample, pad, dilation): 196 | super(BasicBlock, self).__init__() 197 | 198 | self.conv1 = nn.Sequential(convbn(inplanes, planes, 3, stride, pad, dilation), 199 | nn.ReLU(inplace=True)) 200 | 201 | self.conv2 = convbn(planes, planes, 3, 1, pad, dilation) 202 | 203 | self.downsample = downsample 204 | self.stride = stride 205 | 206 | def forward(self, x): 207 | out = self.conv1(x) 208 | out = self.conv2(out) 209 | 210 | if self.downsample is not None: 211 | x = self.downsample(x) 212 | 213 | out += x 214 | 215 | return out 216 | -------------------------------------------------------------------------------- /KITTI12/models/submodule.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import torch.nn as nn 4 | import torch.utils.data 5 | from torch.autograd import Variable 6 | from torch.autograd.function import Function 7 | import torch.nn.functional as F 8 | import numpy as np 9 | 10 | 11 | class Mish(nn.Module): 12 | def __init__(self): 13 | super().__init__() 14 | #print("Mish activation loaded...") 15 | 16 | def forward(self, x): 17 | #save 1 second per epoch with no x= x*() and then return x...just inline it. 18 | return x *( torch.tanh(F.softplus(x))) 19 | 20 | 21 | def convbn(in_channels, out_channels, kernel_size, stride, pad, dilation): 22 | return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, 23 | padding=dilation if dilation > 1 else pad, dilation=dilation, bias=False), 24 | nn.BatchNorm2d(out_channels)) 25 | 26 | 27 | def convbn_3d(in_channels, out_channels, kernel_size, stride, pad): 28 | return nn.Sequential(nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, 29 | padding=pad, bias=False), 30 | nn.BatchNorm3d(out_channels)) 31 | 32 | 33 | def disparity_regression(x, maxdisp): 34 | assert len(x.shape) == 4 35 | disp_values = torch.arange(0, maxdisp, dtype=x.dtype, device=x.device) 36 | disp_values = disp_values.view(1, maxdisp, 1, 1) 37 | return torch.sum(x * disp_values, 1, keepdim=False) 38 | 39 | 40 | def disp_regression_nearby(similarity, disp_step, half_support_window=2): 41 | """Returns predicted disparity with subpixel_map(disp_similarity). 42 | 43 | Predicted disparity is computed as: 44 | 45 | d_predicted = sum_d( d * P_predicted(d)), 46 | where | d - d_similarity_maximum | < half_size 47 | 48 | Args: 49 | similarity: Tensor with similarities with indices 50 | [example_index, disparity_index, y, x]. 51 | disp_step: disparity difference between near-by 52 | disparity indices in "similarities" tensor. 53 | half_support_window: defines size of disparity window in pixels 54 | around disparity with maximum similarity, 55 | which is used to convert similarities 56 | to probabilities and compute mean. 57 | """ 58 | 59 | assert 4 == similarity.dim(), \ 60 | 'Similarity should 4D Tensor,but get {}D Tensor'.format(similarity.dim()) 61 | 62 | # In every location (x, y) find disparity with maximum similarity score. 
63 | similar_maximum, idx_maximum = torch.max(similarity, dim=1, keepdim=True) 64 | idx_limit = similarity.size(1) - 1 65 | 66 | # Collect similarity scores for the disparities around the disparity 67 | # with the maximum similarity score. 68 | support_idx_disp = [] 69 | for idx_shift in range(-half_support_window, half_support_window + 1): 70 | idx_disp = idx_maximum + idx_shift 71 | idx_disp[idx_disp < 0] = 0 72 | idx_disp[idx_disp >= idx_limit] = idx_limit 73 | support_idx_disp.append(idx_disp) 74 | 75 | support_idx_disp = torch.cat(support_idx_disp, dim=1) 76 | support_similar = torch.gather(similarity, 1, support_idx_disp.long()) 77 | support_disp = support_idx_disp.float() * disp_step 78 | 79 | # Convert collected similarity scores to the disparity distribution 80 | # using softmax and compute disparity as a mean of this distribution. 81 | prob = F.softmax(support_similar, dim=1) 82 | disp = torch.sum(prob * support_disp.float(), dim=1) 83 | 84 | return disp 85 | 86 | def build_concat_volume(refimg_fea, targetimg_fea, maxdisp): 87 | B, C, H, W = refimg_fea.shape 88 | volume = refimg_fea.new_zeros([B, 2 * C, maxdisp, H, W]) 89 | for i in range(maxdisp): 90 | if i > 0: 91 | volume[:, :C, i, :, i:] = refimg_fea[:, :, :, i:] 92 | volume[:, C:, i, :, i:] = targetimg_fea[:, :, :, :-i] 93 | else: 94 | volume[:, :C, i, :, :] = refimg_fea 95 | volume[:, C:, i, :, :] = targetimg_fea 96 | volume = volume.contiguous() 97 | return volume 98 | 99 | 100 | def groupwise_correlation(fea1, fea2, num_groups): 101 | B, C, H, W = fea1.shape 102 | assert C % num_groups == 0 103 | channels_per_group = C // num_groups 104 | cost = (fea1 * fea2).view([B, num_groups, channels_per_group, H, W]).mean(dim=2) 105 | assert cost.shape == (B, num_groups, H, W) 106 | return cost 107 | 108 | 109 | def build_gwc_volume(refimg_fea, targetimg_fea, maxdisp, num_groups): 110 | B, C, H, W = refimg_fea.shape 111 | volume = refimg_fea.new_zeros([B, num_groups, maxdisp, H, W]) 112 | for i in range(maxdisp): 113 | if i > 0: 114 | volume[:, :, i, :, i:] = groupwise_correlation(refimg_fea[:, :, :, i:], targetimg_fea[:, :, :, :-i], 115 | num_groups) 116 | else: 117 | volume[:, :, i, :, :] = groupwise_correlation(refimg_fea, targetimg_fea, num_groups) 118 | volume = volume.contiguous() 119 | return volume 120 | 121 | def build_corrleation_volume(refimg_fea, targetimg_fea, maxdisp, num_groups): 122 | B, C, H, W = refimg_fea.shape 123 | volume = refimg_fea.new_zeros([B, num_groups, 2 * maxdisp + 1, H, W]) 124 | for i in range(-maxdisp, maxdisp+1): 125 | if i > 0: 126 | volume[:, :, i + maxdisp, :, i:] = groupwise_correlation(refimg_fea[:, :, :, i:], targetimg_fea[:, :, :, :-i], 127 | num_groups) 128 | elif i < 0: 129 | volume[:, :, i + maxdisp, :, :-i] = groupwise_correlation(refimg_fea[:, :, :, :-i], 130 | targetimg_fea[:, :, :, i:], 131 | num_groups) 132 | else: 133 | volume[:, :, i + maxdisp, :, :] = groupwise_correlation(refimg_fea, targetimg_fea, num_groups) 134 | volume = volume.contiguous() 135 | return volume 136 | 137 | def warp(x, disp): 138 | """ 139 | warp an image/tensor (imright) back to imleft, according to the disp 140 | 141 | x: [B, C, H, W] (imright) 142 | disp: [B, 1, H, W] disp 143 | 144 | """ 145 | B, C, H, W = x.size() 146 | device = x.get_device() 147 | # mesh grid 148 | xx = torch.arange(0, W, device=device).view(1, -1).repeat(H, 1) 149 | yy = torch.arange(0, H, device=device).view(-1, 1).repeat(1, W) 150 | xx = xx.view(1, 1, H, W).repeat(B, 1, 1, 1) 151 | yy = yy.view(1, 1, H, W).repeat(B, 1, 1, 1) 152 | xx = 
xx.float() 153 | yy = yy.float() 154 | # grid = torch.cat((xx, yy), 1).float() 155 | 156 | # if x.is_cuda: 157 | # xx = xx.float().cuda() 158 | # yy = yy.float().cuda() 159 | xx_warp = Variable(xx) - disp 160 | yy = Variable(yy) 161 | # xx_warp = xx - disp 162 | vgrid = torch.cat((xx_warp, yy), 1) 163 | # vgrid = Variable(grid) + flo 164 | # scale grid to [-1,1] 165 | vgrid[:, 0, :, :] = 2.0 * vgrid[:, 0, :, :].clone() / max(W - 1, 1) - 1.0 166 | vgrid[:, 1, :, :] = 2.0 * vgrid[:, 1, :, :].clone() / max(H - 1, 1) - 1.0 167 | 168 | vgrid = vgrid.permute(0, 2, 3, 1) 169 | output = nn.functional.grid_sample(x, vgrid) 170 | mask = torch.ones(x.size(), device=device, requires_grad=True) 171 | mask = nn.functional.grid_sample(mask, vgrid) 172 | 173 | mask[mask < 0.999] = 0 174 | mask[mask > 0] = 1 175 | 176 | return output * mask 177 | 178 | def FMish(x): 179 | 180 | ''' 181 | 182 | Applies the mish function element-wise: 183 | 184 | mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x))) 185 | 186 | See additional documentation for mish class. 187 | 188 | ''' 189 | 190 | return x * torch.tanh(F.softplus(x)) 191 | 192 | class BasicBlock(nn.Module): 193 | expansion = 1 194 | 195 | def __init__(self, inplanes, planes, stride, downsample, pad, dilation): 196 | super(BasicBlock, self).__init__() 197 | 198 | self.conv1 = nn.Sequential(convbn(inplanes, planes, 3, stride, pad, dilation), 199 | Mish()) 200 | 201 | self.conv2 = convbn(planes, planes, 3, 1, pad, dilation) 202 | 203 | self.downsample = downsample 204 | self.stride = stride 205 | 206 | def forward(self, x): 207 | out = self.conv1(x) 208 | out = self.conv2(out) 209 | 210 | if self.downsample is not None: 211 | x = self.downsample(x) 212 | 213 | out += x 214 | 215 | return out 216 | -------------------------------------------------------------------------------- /KITTI12/save_disp_sceneflow_kitti12.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import argparse 3 | import os 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.parallel 7 | import torch.backends.cudnn as cudnn 8 | import torch.optim as optim 9 | import torch.utils.data 10 | from torch.autograd import Variable 11 | import torchvision.utils as vutils 12 | import torch.nn.functional as F 13 | import numpy as np 14 | import time 15 | # from tensorboardX import SummaryWriter 16 | from datasets import __datasets__ 17 | from models import __models__ 18 | from utils import * 19 | from torch.utils.data import DataLoader 20 | import gc 21 | import matplotlib.pyplot as plt 22 | import skimage 23 | import skimage.io 24 | import cv2 25 | 26 | # cudnn.benchmark = True 27 | 28 | os.environ['CUDA_VISIBLE_DEVICES'] = '7' 29 | 30 | parser = argparse.ArgumentParser( 31 | description='Attention Concatenation Volume for Accurate and Efficient Stereo Matching (ACVNet)') 32 | parser.add_argument('--model', default='pwc_ddimgc', help='select a model structure', choices=__models__.keys()) 33 | parser.add_argument('--maxdisp', type=int, default=192, help='maximum disparity') 34 | parser.add_argument('--dataset', default='kitti', help='dataset name', choices=__datasets__.keys()) 35 | parser.add_argument('--datapath', default="/home/zhengdian/dataset/KITTI/2012/", help='data path') 36 | parser.add_argument('--test_batch_size', type=int, default=1, help='testing batch size') 37 | parser.add_argument('--testlist', default='./filenames/test_temp.txt', help='testing list') 38 | 
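Stepping back to the `warp()` helper from submodule.py above: it inverse-warps the right image to the left view by shifting the x mesh grid by disparity, normalizing to [-1, 1], and sampling with `grid_sample`, masking pixels that leave the image. A condensed sketch; `align_corners=True` is an assumption matching the older PyTorch default behavior that the original (which omits the flag) was written against:

```python
import torch
import torch.nn.functional as F

def warp_right_to_left(right, disp):
    # right: B x C x H x W, disp: B x 1 x H x W (positive disparities)
    B, C, H, W = right.shape
    xx = torch.arange(W, device=right.device).view(1, 1, 1, W).expand(B, 1, H, W).float()
    yy = torch.arange(H, device=right.device).view(1, 1, H, 1).expand(B, 1, H, W).float()
    grid = torch.cat([xx - disp, yy], dim=1)
    grid[:, 0] = 2.0 * grid[:, 0] / max(W - 1, 1) - 1.0   # normalize x to [-1, 1]
    grid[:, 1] = 2.0 * grid[:, 1] / max(H - 1, 1) - 1.0   # normalize y to [-1, 1]
    grid = grid.permute(0, 2, 3, 1)                       # B x H x W x 2
    warped = F.grid_sample(right, grid, align_corners=True)
    valid = F.grid_sample(torch.ones_like(right), grid, align_corners=True)
    return warped * (valid > 0.999).float()

print(warp_right_to_left(torch.rand(1, 3, 8, 16), torch.ones(1, 1, 8, 16)).shape)
```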
parser.add_argument('--loadckpt', default='./checkpoints/kitti12/test_all/checkpoint_000244.ckpt') 39 | # parse arguments 40 | args = parser.parse_args() 41 | 42 | # dataset, dataloader 43 | StereoDataset = __datasets__[args.dataset] 44 | test_dataset = StereoDataset(args.datapath, args.testlist, False) 45 | TestImgLoader = DataLoader(test_dataset, args.test_batch_size, shuffle=False, num_workers=4, drop_last=False) 46 | 47 | # model, optimizer 48 | model = __models__[args.model](args.maxdisp) 49 | model = nn.DataParallel(model) 50 | model.cuda() 51 | 52 | model_origin = __models__['gwcnet-gc'](args.maxdisp) 53 | model_origin = nn.DataParallel(model_origin) 54 | model_origin.cuda() 55 | 56 | # load parameters 57 | print("loading model {}".format(args.loadckpt)) 58 | state_dict = torch.load(args.loadckpt) 59 | model.load_state_dict(state_dict['model']) 60 | 61 | state_dict = torch.load('./PCWNet_kitti12_best.ckpt') 62 | model_origin.load_state_dict(state_dict['model']) 63 | 64 | save_dir = './speed_test/' 65 | 66 | 67 | def test(): 68 | os.makedirs(save_dir, exist_ok=True) 69 | for batch_idx, sample in enumerate(TestImgLoader): 70 | torch.cuda.synchronize() 71 | start_time = time.time() 72 | # disp_est_ = test_sample(sample) 73 | # for i in range(len(disp_est_)): 74 | # disp_est_np = tensor2numpy(disp_est_[i]).squeeze(0) 75 | # torch.cuda.synchronize() 76 | # print('Iter {}/{}, time = {:3f}'.format(batch_idx, len(TestImgLoader), 77 | # time.time() - start_time)) 78 | # left_filenames = sample["left_filename"] 79 | # top_pad_np = tensor2numpy(sample["top_pad"]) 80 | # right_pad_np = tensor2numpy(sample["right_pad"]) 81 | # 82 | # for disp_est, top_pad, right_pad, fn in zip(disp_est_np, top_pad_np, right_pad_np, left_filenames): 83 | # assert len(disp_est.shape) == 2 84 | # disp_est = np.array(disp_est[top_pad:, :-right_pad], dtype=np.float32) 85 | # # disp_est = np.array(disp_est, dtype=np.float32) 86 | # fn = os.path.join(save_dir, fn.split('/')[-1]) 87 | # print("saving to", fn, disp_est.shape) 88 | # disp_est_uint = np.round(disp_est * 256).astype(np.uint16) 89 | # # skimage.io.imsave(fn, disp_est_uint) 90 | # plt.imsave(str(i)+'.png', disp_est_uint, cmap='jet') 91 | disp_est_np = tensor2numpy(test_sample(sample)) 92 | torch.cuda.synchronize() 93 | print('Iter {}/{}, time = {:3f}'.format(batch_idx, len(TestImgLoader), 94 | time.time() - start_time)) 95 | left_filenames = sample["left_filename"] 96 | top_pad_np = tensor2numpy(sample["top_pad"]) 97 | right_pad_np = tensor2numpy(sample["right_pad"]) 98 | 99 | for disp_est, top_pad, right_pad, fn in zip(disp_est_np, top_pad_np, right_pad_np, left_filenames): 100 | assert len(disp_est.shape) == 2 101 | disp_est = np.array(disp_est[top_pad:, :-right_pad], dtype=np.float32) 102 | #disp_est = np.array(disp_est, dtype=np.float32) 103 | fn = os.path.join(save_dir, fn.split('/')[-1]) 104 | print("saving to", fn, disp_est.shape) 105 | disp_est_uint = np.round(disp_est * 256).astype(np.uint16) 106 | #skimage.io.imsave(fn, disp_est_uint) 107 | plt.imsave('a.png', disp_est_uint, cmap='jet') 108 | #cv2.imwrite(fn, cv2.applyColorMap(cv2.convertScaleAbs(disp_est_uint, alpha=0.01), cv2.COLORMAP_JET)) 109 | 110 | 111 | # test one sample 112 | @make_nograd_func 113 | def test_sample(sample): 114 | model.eval() 115 | model_origin.eval() 116 | imgL, imgR, filename = sample['left'], sample['right'], sample['left_filename'] 117 | imgL = imgL.cuda() 118 | imgR = imgR.cuda() 119 | 120 | # disp_ests, qwe = model_origin(imgL, imgR) 121 | disp_, qwe = 
model_origin(imgL, imgR) 122 | disp_ = disp_[-1] 123 | disp_net = torch.clamp(disp_, 0, args.maxdisp - 1).unsqueeze(1) 124 | 125 | b, c, h, w = disp_net.shape 126 | disp_net = F.interpolate(disp_net, size=(h // 4, w // 4), mode='bilinear') / 4 127 | 128 | disp_ests, qwe = model(imgL, imgR, disp_, disp_net, None) 129 | 130 | 131 | return disp_ests[-1] 132 | 133 | 134 | if __name__ == '__main__': 135 | test() 136 | -------------------------------------------------------------------------------- /KITTI12/scripts/kitti12.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -x 3 | DATAPATH="/home/zhengdian/dataset/KITTI/2012/" 4 | CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --dataset kitti \ 5 | --datapath $DATAPATH --trainlist ./filenames/kitti12_train.txt --testlist ./filenames/kitti12_val.txt \ 6 | --epochs 300 --lr 0.001 --batch_size 4 --lrepochs "200:10" \ 7 | --model pcw_ddim --logdir ./checkpoints/kitti12/test \ 8 | --test_batch_size 12 -------------------------------------------------------------------------------- /KITTI12/test.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import argparse 3 | import os 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.parallel 7 | import torch.backends.cudnn as cudnn 8 | import torch.optim as optim 9 | import torch.utils.data 10 | from torch.autograd import Variable 11 | import torchvision.utils as vutils 12 | import torch.nn.functional as F 13 | import numpy as np 14 | import time 15 | from datasets import __datasets__ 16 | from models import __models__, model_loss 17 | from utils import * 18 | from torch.utils.data import DataLoader 19 | import gc 20 | import skimage.io 21 | 22 | cudnn.benchmark = True 23 | os.environ['CUDA_VISIBLE_DEVICES'] = '6' 24 | parser = argparse.ArgumentParser(description='PCW-Net: Pyramid Combination and Warping Cost Volume for Stereo Matching') 25 | parser.add_argument('--model', default='pwc_ddimgc', help='select a model structure', choices=__models__.keys()) 26 | parser.add_argument('--maxdisp', type=int, default=192, help='maximum disparity') 27 | parser.add_argument('--test_batchsize', type=int, default=1) 28 | parser.add_argument('--dataset', default='kitti', help='dataset name', choices=__datasets__.keys()) 29 | parser.add_argument('--datapath', default="/mnt/Datasets/KITTI/2012/", help='data path') 30 | parser.add_argument('--testlist', default="./filenames/kitti12_all.txt", help='testing list') 31 | parser.add_argument('--loadckpt', default="./checkpoints/our_best.ckpt", 32 | help='load the weights from a specific checkpoint') 33 | 34 | # parse arguments 35 | args = parser.parse_args() 36 | 37 | # dataset, dataloader 38 | StereoDataset = __datasets__[args.dataset] 39 | test_dataset = StereoDataset(args.datapath, args.testlist, False) 40 | TestImgLoader = DataLoader(test_dataset, args.test_batchsize, shuffle=False, num_workers=4, drop_last=False) 41 | 42 | # model, optimizer 43 | model = __models__[args.model](args.maxdisp) 44 | model = nn.DataParallel(model) 45 | model.cuda() 46 | 47 | # load parameters 48 | print("loading model {}".format(args.loadckpt)) 49 | state_dict = torch.load(args.loadckpt) 50 | model.load_state_dict(state_dict['model']) 51 | 52 | model_origin = __models__['gwcnet-gc'](args.maxdisp) 53 | model_origin = nn.DataParallel(model_origin) 54 | model_origin.cuda() 55 | state_dict = torch.load("./checkpoints/origin.ckpt") 56 | 
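Both this script and test.py below condition DiffuVolume on a coarse disparity from the frozen backbone: clamp it to the valid range, then shrink it to quarter resolution, dividing the values by 4 so they stay consistent with the smaller pixel grid. A standalone sketch with fake inputs:

```python
import torch
import torch.nn.functional as F

maxdisp = 192
disp_ = torch.rand(2, 256, 512) * 300                       # fake backbone output
disp_net = torch.clamp(disp_, 0, maxdisp - 1).unsqueeze(1)  # B x 1 x H x W
b, c, h, w = disp_net.shape
disp_net = F.interpolate(disp_net, size=(h // 4, w // 4), mode='bilinear') / 4
print(disp_net.shape)                                       # torch.Size([2, 1, 64, 128])
```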
model_origin.load_state_dict(state_dict['model']) 57 | 58 | 59 | def test(): 60 | avg_test_scalars = AverageMeterDict() 61 | for batch_idx, sample in enumerate(TestImgLoader): 62 | start_time = time.time() 63 | loss, scalar_outputs = test_sample(sample, compute_metrics=True) 64 | avg_test_scalars.update(scalar_outputs) 65 | del scalar_outputs 66 | print('Iter {}/{}, test loss = {:.3f}, time = {:3f}'.format(batch_idx, 67 | len(TestImgLoader), loss, 68 | time.time() - start_time)) 69 | avg_test_scalars = avg_test_scalars.mean() 70 | print("avg_test_scalars", avg_test_scalars) 71 | gc.collect() 72 | 73 | 74 | # test one sample 75 | @make_nograd_func 76 | def test_sample(sample, compute_metrics=True): 77 | model.eval() 78 | model_origin.eval() 79 | imgL, imgR, disp_gt = sample['left'], sample['right'], sample['disparity'] 80 | imgL = imgL.cuda() 81 | imgR = imgR.cuda() 82 | disp_gt = disp_gt.cuda() 83 | 84 | # disp_ests, qwe = model_origin(imgL, imgR) 85 | 86 | disp_, _ = model_origin(imgL, imgR) 87 | disp_ = disp_[-1] 88 | disp_net = torch.clamp(disp_, 0, args.maxdisp - 1).unsqueeze(1) 89 | b, c, h, w = disp_net.shape 90 | disp_net = F.interpolate(disp_net, size=(h // 4, w // 4), mode='bilinear') / 4 91 | 92 | disp_ests, pred3 = model(imgL, imgR, disp_, disp_net, None) 93 | 94 | mask = (disp_gt < args.maxdisp) & (disp_gt > 0) 95 | loss = model_loss(disp_ests, disp_gt, mask) 96 | 97 | scalar_outputs = {"loss": loss} 98 | #image_outputs = {"disp_est": disp_ests, "disp_gt": disp_gt, "imgL": imgL, "imgR": imgR} 99 | 100 | scalar_outputs["D1"] = [D1_metric(disp_est, disp_gt, mask) for disp_est in disp_ests] 101 | #scalar_outputs["D1_pred3"] = [D1_metric(pred, disp_gt, mask) for pred in pred3] 102 | scalar_outputs["EPE"] = [EPE_metric(disp_est, disp_gt, mask) for disp_est in disp_ests] 103 | scalar_outputs["Thres1"] = [Thres_metric(disp_est, disp_gt, mask, 1.0) for disp_est in disp_ests] 104 | scalar_outputs["Thres2"] = [Thres_metric(disp_est, disp_gt, mask, 2.0) for disp_est in disp_ests] 105 | scalar_outputs["Thres3"] = [Thres_metric(disp_est, disp_gt, mask, 3.0) for disp_est in disp_ests] 106 | 107 | # if compute_metrics: 108 | # image_outputs["errormap"] = [disp_error_image_func()(disp_est, disp_gt) for disp_est in disp_ests] 109 | 110 | return tensor2float(loss), tensor2float(scalar_outputs)#, image_outputs 111 | 112 | 113 | if __name__ == '__main__': 114 | test() 115 | -------------------------------------------------------------------------------- /KITTI12/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from utils.experiment import * 2 | from utils.visualization import * 3 | from utils.metrics import D1_metric, Thres_metric, EPE_metric -------------------------------------------------------------------------------- /KITTI12/utils/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI12/utils/__init__.pyc -------------------------------------------------------------------------------- /KITTI12/utils/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI12/utils/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- 
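The evaluation loop in test.py above accumulates per-batch scalar dicts and averages them at the end via `AverageMeterDict` (defined in utils/experiment.py below). A trimmed, floats-only sketch of that pattern with fake numbers:

```python
import copy

class AverageMeterDict:
    # condensed from utils/experiment.py; handles flat float dicts only
    def __init__(self):
        self.data, self.count = None, 0

    def update(self, x):
        self.count += 1
        if self.data is None:
            self.data = copy.deepcopy(x)
        else:
            for k, v in x.items():
                self.data[k] += v

    def mean(self):
        return {k: v / self.count for k, v in self.data.items()}

meter = AverageMeterDict()
for epe in (0.8, 1.2, 1.0):          # pretend per-batch EPE values
    meter.update({"EPE": epe})
print(meter.mean())                  # {'EPE': 1.0}
```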
/KITTI12/utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI12/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /KITTI12/utils/__pycache__/experiment.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI12/utils/__pycache__/experiment.cpython-37.pyc -------------------------------------------------------------------------------- /KITTI12/utils/__pycache__/experiment.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI12/utils/__pycache__/experiment.cpython-38.pyc -------------------------------------------------------------------------------- /KITTI12/utils/__pycache__/metrics.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI12/utils/__pycache__/metrics.cpython-37.pyc -------------------------------------------------------------------------------- /KITTI12/utils/__pycache__/metrics.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI12/utils/__pycache__/metrics.cpython-38.pyc -------------------------------------------------------------------------------- /KITTI12/utils/__pycache__/visualization.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI12/utils/__pycache__/visualization.cpython-37.pyc -------------------------------------------------------------------------------- /KITTI12/utils/__pycache__/visualization.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI12/utils/__pycache__/visualization.cpython-38.pyc -------------------------------------------------------------------------------- /KITTI12/utils/experiment.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.parallel 5 | import torch.utils.data 6 | from torch.autograd import Variable 7 | import torchvision.utils as vutils 8 | import torch.nn.functional as F 9 | import numpy as np 10 | import copy 11 | 12 | 13 | def make_iterative_func(func): 14 | def wrapper(vars): 15 | if isinstance(vars, list): 16 | return [wrapper(x) for x in vars] 17 | elif isinstance(vars, tuple): 18 | return tuple([wrapper(x) for x in vars]) 19 | elif isinstance(vars, dict): 20 | return {k: wrapper(v) for k, v in vars.items()} 21 | else: 22 | return func(vars) 23 | 24 | return wrapper 25 | 26 | 27 | def make_nograd_func(func): 28 | def wrapper(*f_args, **f_kwargs): 29 | with torch.no_grad(): 30 | ret = func(*f_args, **f_kwargs) 31 | return ret 32 | 33 | 
return wrapper 34 | 35 | 36 | @make_iterative_func 37 | def tensor2float(vars): 38 | if isinstance(vars, float): 39 | return vars 40 | elif isinstance(vars, torch.Tensor): 41 | return vars.data.item() 42 | else: 43 | raise NotImplementedError("invalid input type for tensor2float") 44 | 45 | 46 | @make_iterative_func 47 | def tensor2numpy(vars): 48 | if isinstance(vars, np.ndarray): 49 | return vars 50 | elif isinstance(vars, torch.Tensor): 51 | return vars.data.cpu().numpy() 52 | else: 53 | raise NotImplementedError("invalid input type for tensor2numpy") 54 | 55 | 56 | @make_iterative_func 57 | def check_allfloat(vars): 58 | assert isinstance(vars, float) 59 | 60 | 61 | def save_scalars(logger, mode_tag, scalar_dict, global_step): 62 | scalar_dict = tensor2float(scalar_dict) 63 | for tag, values in scalar_dict.items(): 64 | if not isinstance(values, list) and not isinstance(values, tuple): 65 | values = [values] 66 | for idx, value in enumerate(values): 67 | scalar_name = '{}/{}'.format(mode_tag, tag) 68 | # if len(values) > 1: 69 | scalar_name = scalar_name + "_" + str(idx) 70 | logger.add_scalar(scalar_name, value, global_step) 71 | 72 | 73 | def save_images(logger, mode_tag, images_dict, global_step): 74 | images_dict = tensor2numpy(images_dict) 75 | for tag, values in images_dict.items(): 76 | if not isinstance(values, list) and not isinstance(values, tuple): 77 | values = [values] 78 | for idx, value in enumerate(values): 79 | if len(value.shape) == 3: 80 | value = value[:, np.newaxis, :, :] 81 | value = value[:1] 82 | value = torch.from_numpy(value) 83 | 84 | image_name = '{}/{}'.format(mode_tag, tag) 85 | if len(values) > 1: 86 | image_name = image_name + "_" + str(idx) 87 | logger.add_image(image_name, vutils.make_grid(value, padding=0, nrow=1, normalize=True, scale_each=True), 88 | global_step) 89 | 90 | 91 | def adjust_learning_rate(optimizer, epoch, base_lr, lrepochs): 92 | splits = lrepochs.split(':') 93 | assert len(splits) == 2 94 | 95 | # parse the epochs to downscale the learning rate (before :) 96 | downscale_epochs = [int(eid_str) for eid_str in splits[0].split(',')] 97 | # parse downscale rate (after :) 98 | downscale_rate = float(splits[1]) 99 | print("downscale epochs: {}, downscale rate: {}".format(downscale_epochs, downscale_rate)) 100 | 101 | lr = base_lr 102 | for eid in downscale_epochs: 103 | if epoch >= eid: 104 | lr /= downscale_rate 105 | else: 106 | break 107 | print("setting learning rate to {}".format(lr)) 108 | for param_group in optimizer.param_groups: 109 | param_group['lr'] = lr 110 | 111 | 112 | class AverageMeter(object): 113 | def __init__(self): 114 | self.sum_value = 0. 
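`adjust_learning_rate` above parses schedules of the form "epoch1,epoch2:rate", e.g. the `--lrepochs "200:10"` passed by kitti12.sh earlier: divide the base LR by 10 from epoch 200 on. A minimal sketch of just the parsing and decay logic:

```python
def lr_at(epoch, base_lr, lrepochs="200:10"):
    # "10,20:2" would divide by 2 at epoch 10 and again at epoch 20
    epoch_part, rate_part = lrepochs.split(':')
    downscale_epochs = [int(e) for e in epoch_part.split(',')]
    downscale_rate = float(rate_part)
    lr = base_lr
    for eid in downscale_epochs:
        if epoch >= eid:
            lr /= downscale_rate
        else:
            break
    return lr

print(lr_at(199, 1e-3), lr_at(250, 1e-3))   # 0.001 0.0001
```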
115 |         self.count = 0 116 | 117 |     def update(self, x): 118 |         check_allfloat(x) 119 |         self.sum_value += x 120 |         self.count += 1 121 | 122 |     def mean(self): 123 |         return self.sum_value / self.count 124 | 125 | 126 | class AverageMeterDict(object): 127 |     def __init__(self): 128 |         self.data = None 129 |         self.count = 0 130 | 131 |     def update(self, x): 132 |         check_allfloat(x) 133 |         self.count += 1 134 |         if self.data is None: 135 |             self.data = copy.deepcopy(x) 136 |         else: 137 |             for k1, v1 in x.items(): 138 |                 if isinstance(v1, float): 139 |                     self.data[k1] += v1 140 |                 elif isinstance(v1, tuple) or isinstance(v1, list): 141 |                     for idx, v2 in enumerate(v1): 142 |                         self.data[k1][idx] += v2 143 |                 else: 144 |                     raise NotImplementedError("error input type for update AvgMeterDict") 145 | 146 |     def mean(self): 147 |         @make_iterative_func 148 |         def get_mean(v): 149 |             return v / float(self.count) 150 | 151 |         return get_mean(self.data) 152 | -------------------------------------------------------------------------------- /KITTI12/utils/experiment.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI12/utils/experiment.pyc -------------------------------------------------------------------------------- /KITTI12/utils/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from utils.experiment import make_nograd_func 4 | from torch.autograd import Variable 5 | from torch import Tensor 6 | 7 | 8 | # Update D1 from >3px to >=3px & >5% 9 | # matlab code: 10 | # E = abs(D_gt - D_est); 11 | # n_err = length(find(D_gt > 0 & E > tau(1) & E. / abs(D_gt) > tau(2))); 12 | # n_total = length(find(D_gt > 0)); 13 | # d_err = n_err / n_total; 14 | 15 | def check_shape_for_metric_computation(*vars): 16 |     assert isinstance(vars, tuple) 17 |     for var in vars: 18 |         assert len(var.size()) == 3 19 |         assert var.size() == vars[0].size() 20 | 21 | # a wrapper to compute metrics for each image individually 22 | def compute_metric_for_each_image(metric_func): 23 |     def wrapper(D_ests, D_gts, masks, *nargs): 24 |         check_shape_for_metric_computation(D_ests, D_gts, masks) 25 |         bn = D_gts.shape[0]  # batch size 26 |         results = []  # a list to store results for each image 27 |         # compute result one by one 28 |         for idx in range(bn): 29 |             # if tensor, then pick idx, else pass the same value 30 |             cur_nargs = [x[idx] if isinstance(x, (Tensor, Variable)) else x for x in nargs] 31 |             if masks[idx].float().mean() / (D_gts[idx] > 0).float().mean() < 0.1: 32 |                 print("masks[idx].float().mean() too small, skip") 33 |             else: 34 |                 ret = metric_func(D_ests[idx], D_gts[idx], masks[idx], *cur_nargs) 35 |                 results.append(ret) 36 |         if len(results) == 0: 37 |             print("masks[idx].float().mean() too small for all images in this batch, return 0") 38 |             return torch.tensor(0, dtype=torch.float32, device=D_gts.device) 39 |         else: 40 |             return torch.stack(results).mean() 41 |     return wrapper 42 | 43 | @make_nograd_func 44 | @compute_metric_for_each_image 45 | def D1_metric(D_est, D_gt, mask): 46 |     D_est, D_gt = D_est[mask], D_gt[mask] 47 |     E = torch.abs(D_gt - D_est) 48 |     err_mask = (E > 3) & (E / D_gt.abs() > 0.05) 49 |     return torch.mean(err_mask.float()) 50 | 51 | @make_nograd_func 52 | @compute_metric_for_each_image 53 | def Thres_metric(D_est, D_gt, mask, thres): 54 |     assert isinstance(thres, (int, float)) 55 |     D_est, D_gt = D_est[mask], D_gt[mask] 56
| E = torch.abs(D_gt - D_est) 57 | err_mask = E > thres 58 | return torch.mean(err_mask.float()) 59 | 60 | # NOTE: please do not use this to build up training loss 61 | @make_nograd_func 62 | @compute_metric_for_each_image 63 | def EPE_metric(D_est, D_gt, mask): 64 | D_est, D_gt = D_est[mask], D_gt[mask] 65 | return F.l1_loss(D_est, D_gt, size_average=True) 66 | -------------------------------------------------------------------------------- /KITTI12/utils/metrics.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI12/utils/metrics.pyc -------------------------------------------------------------------------------- /KITTI12/utils/visualization.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import torch.nn as nn 4 | import torch.utils.data 5 | from torch.autograd import Variable, Function 6 | import torch.nn.functional as F 7 | import math 8 | import numpy as np 9 | 10 | 11 | def gen_error_colormap(): 12 | cols = np.array( 13 | [[0 / 3.0, 0.1875 / 3.0, 49, 54, 149], 14 | [0.1875 / 3.0, 0.375 / 3.0, 69, 117, 180], 15 | [0.375 / 3.0, 0.75 / 3.0, 116, 173, 209], 16 | [0.75 / 3.0, 1.5 / 3.0, 171, 217, 233], 17 | [1.5 / 3.0, 3 / 3.0, 224, 243, 248], 18 | [3 / 3.0, 6 / 3.0, 254, 224, 144], 19 | [6 / 3.0, 12 / 3.0, 253, 174, 97], 20 | [12 / 3.0, 24 / 3.0, 244, 109, 67], 21 | [24 / 3.0, 48 / 3.0, 215, 48, 39], 22 | [48 / 3.0, np.inf, 165, 0, 38]], dtype=np.float32) 23 | cols[:, 2: 5] /= 255. 24 | return cols 25 | 26 | 27 | error_colormap = gen_error_colormap() 28 | 29 | 30 | class disp_error_image_func(Function): 31 | def forward(self, D_est_tensor, D_gt_tensor, abs_thres=3., rel_thres=0.05, dilate_radius=1): 32 | D_gt_np = D_gt_tensor.detach().cpu().numpy() 33 | D_est_np = D_est_tensor.detach().cpu().numpy() 34 | B, H, W = D_gt_np.shape 35 | # valid mask 36 | mask = D_gt_np > 0 37 | # error in percentage. When error <= 1, the pixel is valid since <= 3px & 5% 38 | error = np.abs(D_gt_np - D_est_np) 39 | error[np.logical_not(mask)] = 0 40 | error[mask] = np.minimum(error[mask] / abs_thres, (error[mask] / D_gt_np[mask]) / rel_thres) 41 | # get colormap 42 | cols = error_colormap 43 | # create error image 44 | error_image = np.zeros([B, H, W, 3], dtype=np.float32) 45 | for i in range(cols.shape[0]): 46 | error_image[np.logical_and(error >= cols[i][0], error < cols[i][1])] = cols[i, 2:] 47 | # TODO: imdilate 48 | # error_image = cv2.imdilate(D_err, strel('disk', dilate_radius)); 49 | error_image[np.logical_not(mask)] = 0. 
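`D1_metric` above implements the KITTI outlier definition: a valid pixel counts as wrong only if its error exceeds both 3 px and 5% of the true disparity (note also that `EPE_metric`'s `size_average=True` is the deprecated spelling of `reduction='mean'`). A standalone check on fake data:

```python
import torch

def d1(D_est, D_gt):
    mask = D_gt > 0
    E = (D_gt - D_est).abs()[mask]
    outlier = (E > 3) & (E / D_gt[mask].abs() > 0.05)
    return outlier.float().mean()

D_gt = torch.rand(64, 64) * 100 + 1            # fake ground-truth disparities
D_est = D_gt + torch.randn_like(D_gt) * 2      # small noise -> few outliers
print(float(d1(D_est, D_gt)))
```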
50 | # show color tag in the top-left cornor of the image 51 | for i in range(cols.shape[0]): 52 | distance = 20 53 | error_image[:, :10, i * distance:(i + 1) * distance, :] = cols[i, 2:] 54 | 55 | return torch.from_numpy(np.ascontiguousarray(error_image.transpose([0, 3, 1, 2]))) 56 | 57 | def backward(self, grad_output): 58 | return None 59 | -------------------------------------------------------------------------------- /KITTI12/utils/visualization.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI12/utils/visualization.pyc -------------------------------------------------------------------------------- /KITTI15/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI15/core/__init__.py -------------------------------------------------------------------------------- /KITTI15/core/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI15/core/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /KITTI15/core/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI15/core/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /KITTI15/core/__pycache__/extractor.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI15/core/__pycache__/extractor.cpython-37.pyc -------------------------------------------------------------------------------- /KITTI15/core/__pycache__/extractor.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI15/core/__pycache__/extractor.cpython-38.pyc -------------------------------------------------------------------------------- /KITTI15/core/__pycache__/geometry.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI15/core/__pycache__/geometry.cpython-37.pyc -------------------------------------------------------------------------------- /KITTI15/core/__pycache__/geometry.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI15/core/__pycache__/geometry.cpython-38.pyc -------------------------------------------------------------------------------- /KITTI15/core/__pycache__/geometry_ddim.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI15/core/__pycache__/geometry_ddim.cpython-37.pyc -------------------------------------------------------------------------------- /KITTI15/core/__pycache__/geometry_ddim.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI15/core/__pycache__/geometry_ddim.cpython-38.pyc -------------------------------------------------------------------------------- /KITTI15/core/__pycache__/head.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI15/core/__pycache__/head.cpython-37.pyc -------------------------------------------------------------------------------- /KITTI15/core/__pycache__/head.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI15/core/__pycache__/head.cpython-38.pyc -------------------------------------------------------------------------------- /KITTI15/core/__pycache__/igev_stereo.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI15/core/__pycache__/igev_stereo.cpython-37.pyc -------------------------------------------------------------------------------- /KITTI15/core/__pycache__/igev_stereo.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI15/core/__pycache__/igev_stereo.cpython-38.pyc -------------------------------------------------------------------------------- /KITTI15/core/__pycache__/igev_stereo_ddim.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI15/core/__pycache__/igev_stereo_ddim.cpython-37.pyc -------------------------------------------------------------------------------- /KITTI15/core/__pycache__/igev_stereo_ddim.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI15/core/__pycache__/igev_stereo_ddim.cpython-38.pyc -------------------------------------------------------------------------------- /KITTI15/core/__pycache__/stereo_datasets.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI15/core/__pycache__/stereo_datasets.cpython-37.pyc -------------------------------------------------------------------------------- /KITTI15/core/__pycache__/stereo_datasets.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI15/core/__pycache__/stereo_datasets.cpython-38.pyc 
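The `Combined_Geo_Encoding_Volume` class defined next builds a two-level lookup pyramid by average-pooling the reshaped volume along its disparity axis; a toy shape check of that pooling, with hypothetical sizes:

```python
import torch
import torch.nn.functional as F

geo = torch.randn(4, 8, 1, 48)      # (B*H*W, C, 1, D), as reshaped in __init__
pyramid = [geo]
for _ in range(1):                  # num_levels=2 -> one extra level
    geo = F.avg_pool2d(geo, [1, 2], stride=[1, 2])   # halve the D axis
    pyramid.append(geo)
print([tuple(g.shape) for g in pyramid])             # D goes 48 -> 24
```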
-------------------------------------------------------------------------------- /KITTI15/core/__pycache__/submodule.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI15/core/__pycache__/submodule.cpython-37.pyc -------------------------------------------------------------------------------- /KITTI15/core/__pycache__/submodule.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI15/core/__pycache__/submodule.cpython-38.pyc -------------------------------------------------------------------------------- /KITTI15/core/__pycache__/update.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI15/core/__pycache__/update.cpython-37.pyc -------------------------------------------------------------------------------- /KITTI15/core/__pycache__/update.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI15/core/__pycache__/update.cpython-38.pyc -------------------------------------------------------------------------------- /KITTI15/core/geometry.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from core.utils.utils import bilinear_sampler 4 | 5 | 6 | class Combined_Geo_Encoding_Volume: 7 | def __init__(self, init_fmap1, init_fmap2, geo_volume, num_levels=2, radius=4): 8 | self.num_levels = num_levels 9 | self.radius = radius 10 | self.geo_volume_pyramid = [] 11 | self.init_corr_pyramid = [] 12 | 13 | # all pairs correlation 14 | init_corr = Combined_Geo_Encoding_Volume.corr(init_fmap1, init_fmap2) 15 | 16 | b, h, w, _, w2 = init_corr.shape 17 | b, c, d, h, w = geo_volume.shape 18 | geo_volume = geo_volume.permute(0, 3, 4, 1, 2).reshape(b*h*w, c, 1, d) 19 | 20 | init_corr = init_corr.reshape(b*h*w, 1, 1, w2) 21 | self.geo_volume_pyramid.append(geo_volume) 22 | self.init_corr_pyramid.append(init_corr) 23 | for i in range(self.num_levels-1): 24 | geo_volume = F.avg_pool2d(geo_volume, [1,2], stride=[1,2]) 25 | self.geo_volume_pyramid.append(geo_volume) 26 | 27 | for i in range(self.num_levels-1): 28 | init_corr = F.avg_pool2d(init_corr, [1,2], stride=[1,2]) 29 | self.init_corr_pyramid.append(init_corr) 30 | 31 | 32 | 33 | 34 | def __call__(self, disp, coords): 35 | r = self.radius 36 | b, _, h, w = disp.shape 37 | out_pyramid = [] 38 | for i in range(self.num_levels): 39 | geo_volume = self.geo_volume_pyramid[i] 40 | dx = torch.linspace(-r, r, 2*r+1) 41 | dx = dx.view(1, 1, 2*r+1, 1).to(disp.device) 42 | x0 = dx + disp.reshape(b*h*w, 1, 1, 1) / 2**i 43 | y0 = torch.zeros_like(x0) 44 | 45 | disp_lvl = torch.cat([x0,y0], dim=-1) 46 | geo_volume = bilinear_sampler(geo_volume, disp_lvl) 47 | geo_volume = geo_volume.view(b, h, w, -1) 48 | 49 | init_corr = self.init_corr_pyramid[i] 50 | init_x0 = coords.reshape(b*h*w, 1, 1, 1)/2**i - disp.reshape(b*h*w, 1, 1, 1) / 2**i + dx 51 | init_coords_lvl = torch.cat([init_x0,y0], dim=-1) 52 | init_corr = bilinear_sampler(init_corr, init_coords_lvl) 53 | init_corr = 
init_corr.view(b, h, w, -1) 54 | 55 | out_pyramid.append(geo_volume) 56 | out_pyramid.append(init_corr) 57 | out = torch.cat(out_pyramid, dim=-1) 58 | return out.permute(0, 3, 1, 2).contiguous().float() 59 | 60 | 61 | @staticmethod 62 | def corr(fmap1, fmap2): 63 | B, D, H, W1 = fmap1.shape 64 | _, _, _, W2 = fmap2.shape 65 | fmap1 = fmap1.view(B, D, H, W1) 66 | fmap2 = fmap2.view(B, D, H, W2) 67 | corr = torch.einsum('aijk,aijh->ajkh', fmap1, fmap2) 68 | corr = corr.reshape(B, H, W1, 1, W2).contiguous() 69 | return corr -------------------------------------------------------------------------------- /KITTI15/core/geometry_ddim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from core.utils.utils import bilinear_sampler 4 | 5 | 6 | class Combined_Geo_Encoding_Volume: 7 | def __init__(self, init_fmap1, init_fmap2, geo_volume, num_levels=2, radius=4): 8 | self.num_levels = num_levels 9 | self.radius = radius 10 | self.geo_volume_pyramid = [] 11 | self.init_corr_pyramid = [] 12 | 13 | # all pairs correlation 14 | init_corr = Combined_Geo_Encoding_Volume.corr(init_fmap1, init_fmap2) 15 | 16 | b, h, w, _, w2 = init_corr.shape 17 | b, c, d, h, w = geo_volume.shape 18 | self.channel = c 19 | geo_volume = geo_volume.permute(0, 3, 4, 1, 2).reshape(b*h*w, c, 1, d) 20 | 21 | init_corr = init_corr.reshape(b*h*w, 1, 1, w2) 22 | self.geo_volume_pyramid.append(geo_volume) 23 | self.init_corr_pyramid.append(init_corr) 24 | for i in range(self.num_levels-1): 25 | geo_volume = F.avg_pool2d(geo_volume, [1,2], stride=[1,2]) 26 | self.geo_volume_pyramid.append(geo_volume) 27 | 28 | for i in range(self.num_levels-1): 29 | init_corr = F.avg_pool2d(init_corr, [1,2], stride=[1,2]) 30 | self.init_corr_pyramid.append(init_corr) 31 | 32 | 33 | def __call__(self, disp, coords, noisy): 34 | r = self.radius 35 | b, _, h, w = disp.shape 36 | batch, _, h1, w1 = coords.shape 37 | noisy = noisy.reshape(batch*h1*w1, 1, 1, -1) 38 | 39 | noise = [] 40 | noise.append(noisy) 41 | for i in range(self.num_levels): 42 | noisy = F.avg_pool2d(noisy, [1, 2], stride=[1, 2]) 43 | noise.append(noisy) 44 | 45 | out_pyramid = [] 46 | for i in range(self.num_levels): 47 | geo_volume = self.geo_volume_pyramid[i] 48 | noi = noise[i] 49 | dx = torch.linspace(-r, r, 2*r+1) 50 | dx = dx.view(1, 1, 2*r+1, 1).to(disp.device) 51 | x0 = dx + disp.reshape(b*h*w, 1, 1, 1) / 2**i 52 | y0 = torch.zeros_like(x0) 53 | 54 | disp_lvl = torch.cat([x0,y0], dim=-1) 55 | 56 | geo_volume = geo_volume * noi 57 | geo_volume = bilinear_sampler(geo_volume, disp_lvl) 58 | geo_volume = geo_volume.view(b, h, w, -1) 59 | 60 | init_corr = self.init_corr_pyramid[i] 61 | init_x0 = coords.reshape(b*h*w, 1, 1, 1)/2**i - disp.reshape(b*h*w, 1, 1, 1) / 2**i + dx 62 | init_coords_lvl = torch.cat([init_x0,y0], dim=-1) 63 | init_corr = bilinear_sampler(init_corr, init_coords_lvl) 64 | init_corr = init_corr.view(b, h, w, -1) 65 | 66 | out_pyramid.append(geo_volume) 67 | out_pyramid.append(init_corr) 68 | out = torch.cat(out_pyramid, dim=-1) 69 | return out.permute(0, 3, 1, 2).contiguous().float() 70 | 71 | 72 | @staticmethod 73 | def corr(fmap1, fmap2): 74 | B, D, H, W1 = fmap1.shape 75 | _, _, _, W2 = fmap2.shape 76 | fmap1 = fmap1.view(B, D, H, W1) 77 | fmap2 = fmap2.view(B, D, H, W2) 78 | corr = torch.einsum('aijk,aijh->ajkh', fmap1, fmap2) 79 | corr = corr.reshape(B, H, W1, 1, W2).contiguous() 80 | return corr 
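`Combined_Geo_Encoding_Volume.corr` above computes the all-pairs correlation between left and right feature columns in a single einsum; a quick shape check with made-up sizes:

```python
import torch

fmap1 = torch.randn(2, 16, 8, 32)   # B x D x H x W1 left features
fmap2 = torch.randn(2, 16, 8, 32)   # B x D x H x W2 right features
corr = torch.einsum('aijk,aijh->ajkh', fmap1, fmap2)  # -> B x H x W1 x W2
corr = corr.reshape(2, 8, 32, 1, 32).contiguous()     # B x H x W1 x 1 x W2, as above
print(corr.shape)
```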
-------------------------------------------------------------------------------- /KITTI15/core/head.py: -------------------------------------------------------------------------------- 1 | """ 2 | DiffusionDet Transformer class. 3 | 4 | Copy-paste from torch.nn.Transformer with modifications: 5 | * positional encodings are passed in MHattention 6 | * extra LN at the end of encoder is removed 7 | * decoder returns a stack of activations from all decoding layers 8 | """ 9 | import copy 10 | import math 11 | 12 | import numpy as np 13 | import torch 14 | from torch import nn, Tensor 15 | import torch.nn.functional as F 16 | 17 | 18 | 19 | _DEFAULT_SCALE_CLAMP = math.log(100000.0 / 16) 20 | 21 | 22 | class SinusoidalPositionEmbeddings(nn.Module): 23 | def __init__(self, dim): 24 | super().__init__() 25 | self.dim = dim 26 | 27 | def forward(self, time): 28 | device = time.device 29 | half_dim = self.dim // 2 30 | embeddings = math.log(10000) / (half_dim - 1) 31 | embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings) 32 | embeddings = time[:, None] * embeddings[None, :] 33 | embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1) 34 | return embeddings 35 | 36 | 37 | class GaussianFourierProjection(nn.Module): 38 | """Gaussian random features for encoding time steps.""" 39 | 40 | def __init__(self, embed_dim, scale=30.): 41 | super().__init__() 42 | # Randomly sample weights during initialization. These weights are fixed 43 | # during optimization and are not trainable. 44 | self.W = nn.Parameter(torch.randn(embed_dim // 2) * scale, requires_grad=False) 45 | 46 | def forward(self, x): 47 | x_proj = x[:, None] * self.W[None, :] * 2 * np.pi 48 | return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) 49 | 50 | 51 | class DynamicHead(nn.Module): 52 | 53 | def __init__(self, d_model): 54 | super().__init__() 55 | self.d_model = d_model 56 | time_dim = d_model * 4 57 | self.time_mlp = nn.Sequential( 58 | SinusoidalPositionEmbeddings(d_model), 59 | nn.Linear(d_model, time_dim), 60 | nn.GELU(), 61 | nn.Linear(time_dim, time_dim), 62 | ) 63 | self.block_time_mlp = nn.Sequential(nn.SiLU(), nn.Linear(d_model * 4, d_model)) 64 | #self.block_time_mlp = nn.Sequential(nn.SiLU(), nn.Linear(d_model * 4, d_model), nn.Sigmoid()) 65 | 66 | self._reset_parameters() 67 | 68 | def _reset_parameters(self): 69 | # init all parameters. 
70 | for p in self.parameters(): 71 | if p.dim() > 1: 72 | nn.init.xavier_uniform_(p) 73 | 74 | def forward(self, noisy, t): 75 | time_emb = self.time_mlp(t) 76 | scale_shift = self.block_time_mlp(time_emb)#.unsqueeze(-1).unsqueeze(-1) 77 | b, d, h, w = noisy.shape 78 | scale_shift_z = F.interpolate(scale_shift.unsqueeze(0), (d), mode="linear").squeeze(1).unsqueeze(-1).unsqueeze(-1) 79 | 80 | # print(noisy.shape) 81 | # print(scale_shift.shape) 82 | # raise 83 | noisy = noisy + scale_shift_z 84 | #noisy = noisy * scale_shift 85 | # scale, shift = scale_shift.chunk(2, dim=1) 86 | # volume = volume * (scale + 1) + shift 87 | 88 | return noisy -------------------------------------------------------------------------------- /KITTI15/core/update.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from opt_einsum import contract 5 | 6 | class FlowHead(nn.Module): 7 | def __init__(self, input_dim=128, hidden_dim=256, output_dim=2): 8 | super(FlowHead, self).__init__() 9 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 10 | self.conv2 = nn.Conv2d(hidden_dim, output_dim, 3, padding=1) 11 | self.relu = nn.ReLU(inplace=True) 12 | 13 | def forward(self, x): 14 | return self.conv2(self.relu(self.conv1(x))) 15 | 16 | class DispHead(nn.Module): 17 | def __init__(self, input_dim=128, hidden_dim=256, output_dim=1): 18 | super(DispHead, self).__init__() 19 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 20 | self.conv2 = nn.Conv2d(hidden_dim, output_dim, 3, padding=1) 21 | self.relu = nn.ReLU(inplace=True) 22 | 23 | def forward(self, x): 24 | return self.conv2(self.relu(self.conv1(x))) 25 | 26 | class ConvGRU(nn.Module): 27 | def __init__(self, hidden_dim, input_dim, kernel_size=3): 28 | super(ConvGRU, self).__init__() 29 | self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2) 30 | self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2) 31 | self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2) 32 | 33 | def forward(self, h, cz, cr, cq, *x_list): 34 | 35 | x = torch.cat(x_list, dim=1) 36 | hx = torch.cat([h, x], dim=1) 37 | z = torch.sigmoid(self.convz(hx) + cz) 38 | r = torch.sigmoid(self.convr(hx) + cr) 39 | q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)) + cq) 40 | h = (1-z) * h + z * q 41 | return h 42 | 43 | class SepConvGRU(nn.Module): 44 | def __init__(self, hidden_dim=128, input_dim=192+128): 45 | super(SepConvGRU, self).__init__() 46 | self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 47 | self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 48 | self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 49 | 50 | self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 51 | self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 52 | self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 53 | 54 | 55 | def forward(self, h, *x): 56 | # horizontal 57 | x = torch.cat(x, dim=1) 58 | hx = torch.cat([h, x], dim=1) 59 | z = torch.sigmoid(self.convz1(hx)) 60 | r = torch.sigmoid(self.convr1(hx)) 61 | q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1))) 62 | h = (1-z) * h + z * q 63 | 64 | # vertical 65 | hx = torch.cat([h, x], dim=1) 66 | z = torch.sigmoid(self.convz2(hx)) 67 | r = 
torch.sigmoid(self.convr2(hx)) 68 | q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1))) 69 | h = (1-z) * h + z * q 70 | 71 | return h 72 | 73 | class BasicMotionEncoder(nn.Module): 74 | def __init__(self, args): 75 | super(BasicMotionEncoder, self).__init__() 76 | self.args = args 77 | cor_planes = args.corr_levels * (2*args.corr_radius + 1) * (8+1) 78 | self.convc1 = nn.Conv2d(cor_planes, 64, 1, padding=0) 79 | self.convc2 = nn.Conv2d(64, 64, 3, padding=1) 80 | self.convd1 = nn.Conv2d(1, 64, 7, padding=3) 81 | self.convd2 = nn.Conv2d(64, 64, 3, padding=1) 82 | self.conv = nn.Conv2d(64+64, 128-1, 3, padding=1) 83 | 84 | def forward(self, disp, corr): 85 | cor = F.relu(self.convc1(corr)) 86 | cor = F.relu(self.convc2(cor)) 87 | disp_ = F.relu(self.convd1(disp)) 88 | disp_ = F.relu(self.convd2(disp_)) 89 | 90 | cor_disp = torch.cat([cor, disp_], dim=1) 91 | out = F.relu(self.conv(cor_disp)) 92 | return torch.cat([out, disp], dim=1) 93 | 94 | def pool2x(x): 95 | return F.avg_pool2d(x, 3, stride=2, padding=1) 96 | 97 | def pool4x(x): 98 | return F.avg_pool2d(x, 5, stride=4, padding=1) 99 | 100 | def interp(x, dest): 101 | interp_args = {'mode': 'bilinear', 'align_corners': True} 102 | return F.interpolate(x, dest.shape[2:], **interp_args) 103 | 104 | class BasicMultiUpdateBlock(nn.Module): 105 | def __init__(self, args, hidden_dims=[]): 106 | super().__init__() 107 | self.args = args 108 | self.encoder = BasicMotionEncoder(args) 109 | encoder_output_dim = 128 110 | 111 | self.gru04 = ConvGRU(hidden_dims[2], encoder_output_dim + hidden_dims[1] * (args.n_gru_layers > 1)) 112 | self.gru08 = ConvGRU(hidden_dims[1], hidden_dims[0] * (args.n_gru_layers == 3) + hidden_dims[2]) 113 | self.gru16 = ConvGRU(hidden_dims[0], hidden_dims[1]) 114 | self.disp_head = DispHead(hidden_dims[2], hidden_dim=256, output_dim=1) 115 | factor = 2**self.args.n_downsample 116 | 117 | self.mask_feat_4 = nn.Sequential( 118 | nn.Conv2d(hidden_dims[2], 32, 3, padding=1), 119 | nn.ReLU(inplace=True)) 120 | 121 | def forward(self, net, inp, corr=None, disp=None, iter04=True, iter08=True, iter16=True, update=True): 122 | 123 | if iter16: 124 | net[2] = self.gru16(net[2], *(inp[2]), pool2x(net[1])) 125 | if iter08: 126 | if self.args.n_gru_layers > 2: 127 | net[1] = self.gru08(net[1], *(inp[1]), pool2x(net[0]), interp(net[2], net[1])) 128 | else: 129 | net[1] = self.gru08(net[1], *(inp[1]), pool2x(net[0])) 130 | if iter04: 131 | motion_features = self.encoder(disp, corr) 132 | if self.args.n_gru_layers > 1: 133 | net[0] = self.gru04(net[0], *(inp[0]), motion_features, interp(net[1], net[0])) 134 | else: 135 | net[0] = self.gru04(net[0], *(inp[0]), motion_features) 136 | 137 | if not update: 138 | return net 139 | 140 | delta_disp = self.disp_head(net[0]) 141 | mask_feat_4 = self.mask_feat_4(net[0]) 142 | return net, mask_feat_4, delta_disp 143 | -------------------------------------------------------------------------------- /KITTI15/core/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI15/core/utils/__init__.py -------------------------------------------------------------------------------- /KITTI15/core/utils/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI15/core/utils/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /KITTI15/core/utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI15/core/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /KITTI15/core/utils/__pycache__/augmentor.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI15/core/utils/__pycache__/augmentor.cpython-37.pyc -------------------------------------------------------------------------------- /KITTI15/core/utils/__pycache__/augmentor.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI15/core/utils/__pycache__/augmentor.cpython-38.pyc -------------------------------------------------------------------------------- /KITTI15/core/utils/__pycache__/frame_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI15/core/utils/__pycache__/frame_utils.cpython-37.pyc -------------------------------------------------------------------------------- /KITTI15/core/utils/__pycache__/frame_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI15/core/utils/__pycache__/frame_utils.cpython-38.pyc -------------------------------------------------------------------------------- /KITTI15/core/utils/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI15/core/utils/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /KITTI15/core/utils/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/KITTI15/core/utils/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /KITTI15/core/utils/frame_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from os.path import * 4 | import re 5 | import json 6 | import imageio 7 | import cv2 8 | cv2.setNumThreads(0) 9 | cv2.ocl.setUseOpenCL(False) 10 | 11 | TAG_CHAR = np.array([202021.25], np.float32) 12 | 13 | def readFlow(fn): 14 | """ Read .flo file in Middlebury format""" 15 | # Code adapted from: 16 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy 17 | 18 | # WARNING: this will work on little-endian architectures (eg Intel x86) only! 
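# Layout after the float32 magic number (202021.25): int32 width,
# int32 height, then 2*w*h float32 values with u and v interleaved.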
19 | # print 'fn = %s'%(fn) 20 | with open(fn, 'rb') as f: 21 | magic = np.fromfile(f, np.float32, count=1) 22 | if 202021.25 != magic: 23 | print('Magic number incorrect. Invalid .flo file') 24 | return None 25 | else: 26 | w = np.fromfile(f, np.int32, count=1) 27 | h = np.fromfile(f, np.int32, count=1) 28 | # print 'Reading %d x %d flo file\n' % (w, h) 29 | data = np.fromfile(f, np.float32, count=2*int(w)*int(h)) 30 | # Reshape data into 3D array (columns, rows, bands) 31 | # The reshape here is for visualization, the original code is (w,h,2) 32 | return np.resize(data, (int(h), int(w), 2)) 33 | 34 | def readPFM(file): 35 | file = open(file, 'rb') 36 | 37 | color = None 38 | width = None 39 | height = None 40 | scale = None 41 | endian = None 42 | 43 | header = file.readline().rstrip() 44 | if header == b'PF': 45 | color = True 46 | elif header == b'Pf': 47 | color = False 48 | else: 49 | raise Exception('Not a PFM file.') 50 | 51 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline()) 52 | if dim_match: 53 | width, height = map(int, dim_match.groups()) 54 | else: 55 | raise Exception('Malformed PFM header.') 56 | 57 | scale = float(file.readline().rstrip()) 58 | if scale < 0: # little-endian 59 | endian = '<' 60 | scale = -scale 61 | else: 62 | endian = '>' # big-endian 63 | 64 | data = np.fromfile(file, endian + 'f') 65 | shape = (height, width, 3) if color else (height, width) 66 | 67 | data = np.reshape(data, shape) 68 | data = np.flipud(data) 69 | return data 70 | 71 | def writePFM(file, array): 72 | import os 73 | assert type(file) is str and type(array) is np.ndarray and \ 74 | os.path.splitext(file)[1] == ".pfm" 75 | with open(file, 'wb') as f: 76 | H, W = array.shape 77 | headers = ["Pf\n", f"{W} {H}\n", "-1\n"] 78 | for header in headers: 79 | f.write(str.encode(header)) 80 | array = np.flip(array, axis=0).astype(np.float32) 81 | f.write(array.tobytes()) 82 | 83 | 84 | 85 | def writeFlow(filename,uv,v=None): 86 | """ Write optical flow to file. 87 | 88 | If v is None, uv is assumed to contain both u and v channels, 89 | stacked in depth. 90 | Original code by Deqing Sun, adapted from Daniel Scharstein. 
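File layout, as written below: float32 magic number, int32 width,
int32 height, then row-major float32 data with u and v interleaved per pixel.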
91 | """ 92 | nBands = 2 93 | 94 | if v is None: 95 | assert(uv.ndim == 3) 96 | assert(uv.shape[2] == 2) 97 | u = uv[:,:,0] 98 | v = uv[:,:,1] 99 | else: 100 | u = uv 101 | 102 | assert(u.shape == v.shape) 103 | height,width = u.shape 104 | f = open(filename,'wb') 105 | # write the header 106 | f.write(TAG_CHAR) 107 | np.array(width).astype(np.int32).tofile(f) 108 | np.array(height).astype(np.int32).tofile(f) 109 | # arrange into matrix form 110 | tmp = np.zeros((height, width*nBands)) 111 | tmp[:,np.arange(width)*2] = u 112 | tmp[:,np.arange(width)*2 + 1] = v 113 | tmp.astype(np.float32).tofile(f) 114 | f.close() 115 | 116 | 117 | def readFlowKITTI(filename): 118 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR) 119 | flow = flow[:,:,::-1].astype(np.float32) 120 | flow, valid = flow[:, :, :2], flow[:, :, 2] 121 | flow = (flow - 2**15) / 64.0 122 | return flow, valid 123 | 124 | def readDispKITTI(filename): 125 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0 126 | valid = disp > 0.0 127 | return disp, valid 128 | 129 | # Method taken from /n/fs/raft-depth/RAFT-Stereo/datasets/SintelStereo/sdk/python/sintel_io.py 130 | def readDispSintelStereo(file_name): 131 | a = np.array(Image.open(file_name)) 132 | d_r, d_g, d_b = np.split(a, axis=2, indices_or_sections=3) 133 | disp = (d_r * 4 + d_g / (2**6) + d_b / (2**14))[..., 0] 134 | mask = np.array(Image.open(file_name.replace('disparities', 'occlusions'))) 135 | valid = ((mask == 0) & (disp > 0)) 136 | return disp, valid 137 | 138 | # Method taken from https://research.nvidia.com/sites/default/files/pubs/2018-06_Falling-Things/readme_0.txt 139 | def readDispFallingThings(file_name): 140 | a = np.array(Image.open(file_name)) 141 | with open('/'.join(file_name.split('/')[:-1] + ['_camera_settings.json']), 'r') as f: 142 | intrinsics = json.load(f) 143 | fx = intrinsics['camera_settings'][0]['intrinsic_settings']['fx'] 144 | disp = (fx * 6.0 * 100) / a.astype(np.float32) 145 | valid = disp > 0 146 | return disp, valid 147 | 148 | # Method taken from https://github.com/castacks/tartanair_tools/blob/master/data_type.md 149 | def readDispTartanAir(file_name): 150 | depth = np.load(file_name) 151 | disp = 80.0 / depth 152 | valid = disp > 0 153 | return disp, valid 154 | 155 | 156 | def readDispMiddlebury(file_name): 157 | assert basename(file_name) == 'disp0GT.pfm' 158 | disp = readPFM(file_name).astype(np.float32) 159 | assert len(disp.shape) == 2 160 | nocc_pix = file_name.replace('disp0GT.pfm', 'mask0nocc.png') 161 | assert exists(nocc_pix) 162 | nocc_pix = imageio.imread(nocc_pix) == 255 163 | assert np.any(nocc_pix) 164 | return disp, nocc_pix 165 | 166 | def writeFlowKITTI(filename, uv): 167 | uv = 64.0 * uv + 2**15 168 | valid = np.ones([uv.shape[0], uv.shape[1], 1]) 169 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) 170 | cv2.imwrite(filename, uv[..., ::-1]) 171 | 172 | 173 | def read_gen(file_name, pil=False): 174 | ext = splitext(file_name)[-1] 175 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': 176 | return Image.open(file_name) 177 | elif ext == '.bin' or ext == '.raw': 178 | return np.load(file_name) 179 | elif ext == '.flo': 180 | return readFlow(file_name).astype(np.float32) 181 | elif ext == '.pfm': 182 | flow = readPFM(file_name).astype(np.float32) 183 | if len(flow.shape) == 2: 184 | return flow 185 | else: 186 | return flow[:, :, :-1] 187 | return [] -------------------------------------------------------------------------------- /KITTI15/core/utils/utils.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from scipy import interpolate 5 | 6 | 7 | class InputPadder: 8 | """ Pads images such that dimensions are divisible by 8 """ 9 | def __init__(self, dims, mode='sintel', divis_by=8): 10 | self.ht, self.wd = dims[-2:] 11 | pad_ht = (((self.ht // divis_by) + 1) * divis_by - self.ht) % divis_by 12 | pad_wd = (((self.wd // divis_by) + 1) * divis_by - self.wd) % divis_by 13 | if mode == 'sintel': 14 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] 15 | else: 16 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] 17 | 18 | def pad(self, *inputs): 19 | assert all((x.ndim == 4) for x in inputs) 20 | return [F.pad(x, self._pad, mode='replicate') for x in inputs] 21 | 22 | def unpad(self, x): 23 | assert x.ndim == 4 24 | ht, wd = x.shape[-2:] 25 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] 26 | return x[..., c[0]:c[1], c[2]:c[3]] 27 | 28 | def forward_interpolate(flow): 29 | flow = flow.detach().cpu().numpy() 30 | dx, dy = flow[0], flow[1] 31 | 32 | ht, wd = dx.shape 33 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) 34 | 35 | x1 = x0 + dx 36 | y1 = y0 + dy 37 | 38 | x1 = x1.reshape(-1) 39 | y1 = y1.reshape(-1) 40 | dx = dx.reshape(-1) 41 | dy = dy.reshape(-1) 42 | 43 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) 44 | x1 = x1[valid] 45 | y1 = y1[valid] 46 | dx = dx[valid] 47 | dy = dy[valid] 48 | 49 | flow_x = interpolate.griddata( 50 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0) 51 | 52 | flow_y = interpolate.griddata( 53 | (x1, y1), dy, (x0, y0), method='nearest', fill_value=0) 54 | 55 | flow = np.stack([flow_x, flow_y], axis=0) 56 | return torch.from_numpy(flow).float() 57 | 58 | 59 | def bilinear_sampler(img, coords, mode='bilinear', mask=False): 60 | """ Wrapper for grid_sample, uses pixel coordinates """ 61 | H, W = img.shape[-2:] 62 | 63 | # print("$$$55555", img.shape, coords.shape) 64 | xgrid, ygrid = coords.split([1,1], dim=-1) 65 | xgrid = 2*xgrid/(W-1) - 1 66 | 67 | # print("######88888", xgrid) 68 | assert torch.unique(ygrid).numel() == 1 and H == 1 # This is a stereo problem 69 | 70 | grid = torch.cat([xgrid, ygrid], dim=-1) 71 | # print("###37777", grid.shape) 72 | img = F.grid_sample(img, grid, align_corners=True) 73 | if mask: 74 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 75 | return img, mask.float() 76 | 77 | return img 78 | 79 | 80 | def coords_grid(batch, ht, wd): 81 | coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) 82 | coords = torch.stack(coords[::-1], dim=0).float() 83 | return coords[None].repeat(batch, 1, 1, 1) 84 | 85 | 86 | def upflow8(flow, mode='bilinear'): 87 | new_size = (8 * flow.shape[2], 8 * flow.shape[3]) 88 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 89 | 90 | def gauss_blur(input, N=5, std=1): 91 | B, D, H, W = input.shape 92 | x, y = torch.meshgrid(torch.arange(N).float() - N//2, torch.arange(N).float() - N//2) 93 | unnormalized_gaussian = torch.exp(-(x.pow(2) + y.pow(2)) / (2 * std ** 2)) 94 | weights = unnormalized_gaussian / unnormalized_gaussian.sum().clamp(min=1e-4) 95 | weights = weights.view(1,1,N,N).to(input) 96 | output = F.conv2d(input.reshape(B*D,1,H,W), weights, padding=N//2) 97 | return output.view(B, D, H, W) -------------------------------------------------------------------------------- /KITTI15/run.sh: 
-------------------------------------------------------------------------------- 1 | #train 2 | # python train_stereo.py --logdir ./checkpoints/kitti --restore_ckpt ./pretrained_models/kitti/kitti15.pth --train_datasets kitti 3 | #test 4 | python evaluate_stereo.py --restore_ckpt /home/zhengdian/code/DiffuVolume_github/KITTI15_IGEV/checkpoints/10000_igev-stereo.pth --dataset kitti -------------------------------------------------------------------------------- /KITTI15/save_disp.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('core') 3 | 4 | import argparse 5 | import glob 6 | import numpy as np 7 | import torch 8 | from tqdm import tqdm 9 | from pathlib import Path 10 | from core.igev_stereo import IGEVStereo, autocast 11 | from core.igev_stereo_ddim import IGEVStereo_ddim 12 | from utils.utils import InputPadder 13 | import torch.nn.functional as F 14 | from PIL import Image 15 | from matplotlib import pyplot as plt 16 | import os 17 | import skimage.io 18 | import cv2 19 | 20 | 21 | DEVICE = 'cuda' 22 | 23 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 24 | 25 | def load_image(imfile): 26 | img = np.array(Image.open(imfile)).astype(np.uint8) 27 | img = torch.from_numpy(img).permute(2, 0, 1).float() 28 | return img[None].to(DEVICE) 29 | 30 | def demo(args): 31 | model_origin = torch.nn.DataParallel(IGEVStereo(args), device_ids=[0]) 32 | model_origin.load_state_dict(torch.load(args.pretrained_ckpt)) 33 | 34 | model_origin = model_origin.module 35 | model_origin.to(DEVICE) 36 | model_origin.eval() 37 | 38 | model = torch.nn.DataParallel(IGEVStereo_ddim(args), device_ids=[0]) 39 | model.load_state_dict(torch.load(args.restore_ckpt)) 40 | 41 | model = model.module 42 | model.to(DEVICE) 43 | model.eval() 44 | 45 | output_directory = Path(args.output_directory) 46 | output_directory.mkdir(exist_ok=True) 47 | 48 | with torch.no_grad(): 49 | left_images = sorted(glob.glob(args.left_imgs, recursive=True)) 50 | right_images = sorted(glob.glob(args.right_imgs, recursive=True)) 51 | print(f"Found {len(left_images)} images. 
Saving files to {output_directory}/") 52 | 53 | for (imfile1, imfile2) in tqdm(list(zip(left_images, right_images))): 54 | image1 = load_image(imfile1) 55 | image2 = load_image(imfile2) 56 | padder = InputPadder(image1.shape, divis_by=32) 57 | image1, image2 = padder.pad(image1, image2) 58 | mixed_prec=False 59 | iters=32 60 | with autocast(enabled=mixed_prec): 61 | flow_pr = model_origin(image1, image2, iters=iters, test_mode=True) 62 | 63 | b, c, h, w = image1.shape 64 | flow_ori = torch.clamp(flow_pr, 0, w-1) 65 | flow_4 = F.interpolate(flow_ori, size=(h // 4, w // 4), mode='bilinear') / 4 66 | 67 | with autocast(enabled=mixed_prec): 68 | _, disp = model(image1, image2, flow_pr, flow_4, iters=iters, test_mode=True) 69 | disp = padder.unpad(disp.unsqueeze(1)).cpu().squeeze(0) 70 | file_stem = os.path.join(output_directory, imfile1.split('/')[-1]) 71 | disp = disp.cpu().numpy().squeeze() 72 | disp = np.round(disp * 256).astype(np.uint16) 73 | skimage.io.imsave(file_stem, disp) 74 | 75 | 76 | if __name__ == '__main__': 77 | parser = argparse.ArgumentParser() 78 | parser.add_argument('--pretrained_ckpt', help="restore checkpoint", default='./pretrained_models/kitti/kitti15.pth') 79 | parser.add_argument('--restore_ckpt', help="restore checkpoint", default='./checkpoints/10000_igev-stereo.pth') 80 | parser.add_argument('--save_numpy', action='store_true', help='save output as numpy arrays') 81 | parser.add_argument('-l', '--left_imgs', help="path to all first (left) frames", default="/mnt/Datasets/KITTI/2015/testing/image_2/*_10.png") 82 | parser.add_argument('-r', '--right_imgs', help="path to all second (right) frames", default="/mnt/Datasets/KITTI/2015/testing/image_3/*_10.png") 83 | parser.add_argument('--output_directory', help="directory to save output", default="output") 84 | parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision') 85 | parser.add_argument('--valid_iters', type=int, default=16, help='number of flow-field updates during forward pass') 86 | 87 | # Architecture choices 88 | parser.add_argument('--hidden_dims', nargs='+', type=int, default=[128]*3, help="hidden state and context dimensions") 89 | parser.add_argument('--corr_implementation', choices=["reg", "alt", "reg_cuda", "alt_cuda"], default="reg", help="correlation volume implementation") 90 | parser.add_argument('--shared_backbone', action='store_true', help="use a single backbone for the context and feature encoders") 91 | parser.add_argument('--corr_levels', type=int, default=2, help="number of levels in the correlation pyramid") 92 | parser.add_argument('--corr_radius', type=int, default=4, help="width of the correlation pyramid") 93 | parser.add_argument('--n_downsample', type=int, default=2, help="resolution of the disparity field (1/2^K)") 94 | parser.add_argument('--slow_fast_gru', action='store_true', help="iterate the low-res GRUs more frequently") 95 | parser.add_argument('--n_gru_layers', type=int, default=3, help="number of hidden GRU levels") 96 | parser.add_argument('--max_disp', type=int, default=192, help="max disp of geometry encoding volume") 97 | 98 | args = parser.parse_args() 99 | 100 | demo(args) 101 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 iSEE 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files 
(the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## DiffuVolume: Diffusion Model for Volume based Stereo Matching
Official PyTorch Implementation of DiffuVolume.
2 | 
3 | [Paper](https://arxiv.org/pdf/2308.15989.pdf) | [Personal HomePage](https://zhengdian1.github.io)
4 | 
5 | ### Updates
6 | [**2025.01.15**] 🎉🎉🎉 DiffuVolume has finally been accepted by IJCV 2025 after a long wait! 🎉🎉🎉
7 | [**2024.05.06**] We refined our code for a better user experience
8 | [**2024.03.17**] The **pretrained weights** of DiffuVolume are released at [link1](https://drive.google.com/drive/folders/1aCmW6-MBBkvJ4pQ3_AchxzzrezHmArEp?usp=drive_link)
9 | [**2024.03.16**] The **full training and testing code** is released!
10 | [**2023.08.31**] Our DiffuVolume paper was submitted to IJCV
11 | 
12 | ## Introduction
13 | 
14 | Cost-volume-based stereo matching methods must build a redundant cost volume, which interferes with model training and limits performance. In this work, we build a volume filter based on the diffusion model, named DiffuVolume, which uses only the diffusion algorithm, without the heavy U-Net, to iteratively remove redundant information from the cost volume. By adding DiffuVolume to well-performing methods, we outperform all published volume-based methods on Scene Flow, KITTI, and zero-shot benchmarks.
15 | 
16 | ### Training Framework
17 | ![image](Images/diffuvolume.png)
18 | ### Inference Framework
19 | ![image](Images/infer.png)
20 | 
21 | # How to use
22 | 
23 | ## Environment
24 | * Python 3.8
25 | * PyTorch 2.0
26 | 
27 | ## Install
28 | 
29 | ### Create a virtual environment and activate it.
30 | 
31 | ```
32 | conda create -n diffuvolume python=3.8
33 | conda activate diffuvolume
34 | ```
35 | ### Dependencies
36 | 
37 | ```
38 | conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch -c nvidia
39 | pip install opencv-python
40 | pip install scikit-image
41 | pip install tensorboard
42 | pip install matplotlib
43 | pip install tqdm
44 | ```
45 | 
46 | ## Data Preparation
47 | Download [Scene Flow Datasets](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html), [KITTI 2012](http://www.cvlibs.net/datasets/kitti/eval_stereo_flow.php?benchmark=stereo), [KITTI 2015](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=stereo)
48 | 
49 | ## Train
50 | Our DiffuVolume is a plug-and-play module for existing volume-based methods. Here we show how to train on Scene Flow, KITTI 2012, and KITTI 2015.
51 | 
52 | Scene Flow (using the pretrained ACVNet model)
53 | ```
54 | cd SceneFlow
55 | python main.py
56 | ```
57 | 
58 | KITTI 2012 (using the pretrained PCWNet model)
59 | ```
60 | cd KITTI12
61 | python main.py
62 | ```
63 | 
64 | KITTI 2015 (using the pretrained IGEV-Stereo model)
65 | ```
66 | cd KITTI15
67 | sh run.sh
68 | ```
69 | 
70 | ## Test and Visualize
71 | Scene Flow
72 | ```
73 | cd SceneFlow
74 | python test_sceneflow_ddim.py
75 | python save_disp_sceneflow.py
76 | ```
77 | 
78 | KITTI 2012
79 | ```
80 | cd KITTI12
81 | python test.py
82 | python save_disp_sceneflow_kitti12.py
83 | ```
84 | 
85 | KITTI 2015
86 | ```
87 | cd KITTI15
88 | sh run.sh
89 | python save_disp.py
90 | ```
91 | 
92 | 
93 | ## Results on the KITTI 2015 leaderboard
94 | [Leaderboard Link 2015](https://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=stereo&eval_gt=noc&eval_area=all)
95 | 
96 | | Method | D1-bg (All) | D1-fg (All) | D1-all (All) | Runtime (s) |
97 | |:-:|:-:|:-:|:-:|:-:|
98 | | DiffuVolume | 1.35 % | 2.51 % | 1.54 % | 0.18 |
99 | | IGEV | 1.38 % | 2.67 % | 1.59 % | 0.18 |
100 | | ACVNet | 1.37 % | 3.07 % | 1.65 % | 0.20 |
101 | | GwcNet | 1.74 % | 3.93 % | 2.11 % | 0.32 |
102 | | PSMNet | 1.86 % | 4.62 % | 2.32 % | 0.41 |
103 | 
104 | ## Comparison with traditional diffusion-based stereo matching
105 | 
106 | | Method | EPE (px) | Bad1.0 | Runtime (s) | Params (M) |
107 | |:-:|:-:|:-:|:-:|:-:|
108 | | DiffuVolume | 0.46 | 4.97 % | 1.11 | 7.23 |
109 | | DDPM | 0.59 | 6.06 % | 265 | 60.07 |
110 | | DDIM | 0.63 | 6.13 % | 1.21 | 60.07 |
111 | 
112 | ## Qualitative results on ETH3D and Middlebury
113 | 
114 | We show the zero-shot generalization results of DiffuVolume compared with the current SOTA method IGEV.
115 | 
116 | ![image](Images/zero.png)
117 | 
118 | # Citation
119 | 
120 | If you find this project helpful in your research, please consider citing the paper.
121 | 
122 | ```
123 | @article{zheng2023diffuvolume,
124 |   title={DiffuVolume: Diffusion Model for Volume based Stereo Matching},
125 |   author={Zheng, Dian and Wu, Xiao-Ming and Liu, Zuhao and Meng, Jingke and Zheng, Wei-shi},
126 |   journal={arXiv preprint arXiv:2308.15989},
127 |   year={2023}
128 | }
129 | 
130 | ```
131 | 
132 | # Acknowledgements
133 | 
134 | Thanks to Gangwei Xu for open-sourcing his excellent works ACVNet and IGEV-Stereo. Our work is inspired by these works, and part of the code is migrated from [ACVNet](https://github.com/gangweiX/ACVNet) and [IGEV](https://github.com/gangweiX/IGEV).
135 | Thanks to Zhelun Shen for open-sourcing his excellent work PCWNet. Our work is inspired by this work, and part of the code is migrated from [PCWNet](https://github.com/gallenszl/PCWNet).
136 | 
137 | # Contact
138 | 
139 | Please contact Dian Zheng with any questions (1423606603@qq.com or zhengd35@mail2.sysu.edu.cn).
140 | 
--------------------------------------------------------------------------------
/SceneFlow/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2022 gangweiX
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/SceneFlow/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .kitti_dataset import KITTIDataset
2 | from .kitti_dataset_1215 import KITTIDataset1215
3 | from .sceneflow_dataset import SceneFlowDatset
4 | 
5 | __datasets__ = {
6 |     "sceneflow": SceneFlowDatset,
7 |     "kitti": KITTIDataset,
8 |     "kitti1215": KITTIDataset1215
9 | }
10 | 
--------------------------------------------------------------------------------
/SceneFlow/datasets/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/SceneFlow/datasets/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/SceneFlow/datasets/__pycache__/data_io.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/SceneFlow/datasets/__pycache__/data_io.cpython-38.pyc
--------------------------------------------------------------------------------
/SceneFlow/datasets/__pycache__/flow_transforms.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/SceneFlow/datasets/__pycache__/flow_transforms.cpython-38.pyc
--------------------------------------------------------------------------------
/SceneFlow/datasets/__pycache__/kitti_dataset.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/SceneFlow/datasets/__pycache__/kitti_dataset.cpython-38.pyc -------------------------------------------------------------------------------- /SceneFlow/datasets/__pycache__/kitti_dataset_1215.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/SceneFlow/datasets/__pycache__/kitti_dataset_1215.cpython-38.pyc -------------------------------------------------------------------------------- /SceneFlow/datasets/__pycache__/sceneflow_dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/DiffuVolume/df4de31d183cff51a72e2a667e8d20397e55110c/SceneFlow/datasets/__pycache__/sceneflow_dataset.cpython-38.pyc -------------------------------------------------------------------------------- /SceneFlow/datasets/data_io.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import re 3 | import torchvision.transforms as transforms 4 | 5 | 6 | def get_transform(): 7 | mean = [0.485, 0.456, 0.406] 8 | std = [0.229, 0.224, 0.225] 9 | 10 | return transforms.Compose([ 11 | transforms.ToTensor(), 12 | transforms.Normalize(mean=mean, std=std), 13 | ]) 14 | 15 | def get_transform_aug(): 16 | mean = [0.485, 0.456, 0.406] 17 | std = [0.229, 0.224, 0.225] 18 | 19 | return transforms.Compose([ 20 | transforms.ToTensor(), 21 | ]) 22 | 23 | 24 | # read all lines in a file 25 | def read_all_lines(filename): 26 | with open(filename) as f: 27 | lines = [line.rstrip() for line in f.readlines()] 28 | return lines 29 | 30 | 31 | # read an .pfm file into numpy array, used to load SceneFlow disparity files 32 | def pfm_imread(filename): 33 | file = open(filename, 'rb') 34 | color = None 35 | width = None 36 | height = None 37 | scale = None 38 | endian = None 39 | 40 | header = file.readline().decode('utf-8').rstrip() 41 | if header == 'PF': 42 | color = True 43 | elif header == 'Pf': 44 | color = False 45 | else: 46 | raise Exception('Not a PFM file.') 47 | 48 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8')) 49 | if dim_match: 50 | width, height = map(int, dim_match.groups()) 51 | else: 52 | raise Exception('Malformed PFM header.') 53 | 54 | scale = float(file.readline().rstrip()) 55 | if scale < 0: # little-endian 56 | endian = '<' 57 | scale = -scale 58 | else: 59 | endian = '>' # big-endian 60 | 61 | data = np.fromfile(file, endian + 'f') 62 | shape = (height, width, 3) if color else (height, width) 63 | 64 | data = np.reshape(data, shape) 65 | data = np.flipud(data) 66 | return data, scale 67 | -------------------------------------------------------------------------------- /SceneFlow/datasets/flow_transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import random 4 | import numpy as np 5 | import numbers 6 | import pdb 7 | import cv2 8 | 9 | 10 | class Compose(object): 11 | """ Composes several co_transforms together. 
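Each transform is called as t(input, target) and must return the
(possibly modified) pair, so the transforms are applied in order.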
12 | """ 13 | 14 | def __init__(self, co_transforms): 15 | self.co_transforms = co_transforms 16 | 17 | def __call__(self, input, target): 18 | for t in self.co_transforms: 19 | input,target = t(input,target) 20 | return input,target 21 | 22 | 23 | 24 | class Scale(object): 25 | """ Rescales the inputs and target arrays to the given 'size'. 26 | """ 27 | 28 | def __init__(self, size, order=2): 29 | self.ratio = size 30 | self.order = order 31 | if order==0: 32 | self.code=cv2.INTER_NEAREST 33 | elif order==1: 34 | self.code=cv2.INTER_LINEAR 35 | elif order==2: 36 | self.code=cv2.INTER_CUBIC 37 | 38 | def __call__(self, inputs, target): 39 | h, w, _ = inputs[0].shape 40 | ratio = self.ratio 41 | 42 | inputs[0] = cv2.resize(inputs[0], None, fx=ratio,fy=ratio,interpolation=cv2.INTER_CUBIC) 43 | inputs[1] = cv2.resize(inputs[1], None, fx=ratio,fy=ratio,interpolation=cv2.INTER_CUBIC) 44 | target = cv2.resize(target, None, fx=ratio,fy=ratio,interpolation=self.code) * ratio 45 | 46 | return inputs, target 47 | 48 | 49 | class RandomCrop(object): 50 | """ Randomly crop images 51 | """ 52 | 53 | def __init__(self, size): 54 | if isinstance(size, numbers.Number): 55 | self.size = (int(size), int(size)) 56 | else: 57 | self.size = size 58 | 59 | def __call__(self, inputs,target): 60 | h, w, _ = inputs[0].shape 61 | th, tw = self.size 62 | if w < tw: tw=w 63 | if h < th: th=h 64 | 65 | x1 = random.randint(0, w - tw) 66 | y1 = random.randint(0, h - th) 67 | inputs[0] = inputs[0][y1: y1 + th,x1: x1 + tw] 68 | inputs[1] = inputs[1][y1: y1 + th,x1: x1 + tw] 69 | return inputs, target[y1: y1 + th,x1: x1 + tw] 70 | 71 | 72 | class RandomVdisp(object): 73 | """Random vertical disparity augmentation 74 | """ 75 | 76 | def __init__(self, angle, px, diff_angle=0, order=2, reshape=False): 77 | self.angle = angle 78 | self.reshape = reshape 79 | self.order = order 80 | self.diff_angle = diff_angle 81 | self.px = px 82 | 83 | def __call__(self, inputs,target): 84 | px2 = random.uniform(-self.px,self.px) 85 | angle2 = random.uniform(-self.angle,self.angle) 86 | 87 | image_center = (np.random.uniform(0,inputs[1].shape[0]),\ 88 | np.random.uniform(0,inputs[1].shape[1])) 89 | rot_mat = cv2.getRotationMatrix2D(image_center, angle2, 1.0) 90 | inputs[1] = cv2.warpAffine(inputs[1], rot_mat, inputs[1].shape[1::-1], flags=cv2.INTER_LINEAR) 91 | trans_mat = np.float32([[1,0,0],[0,1,px2]]) 92 | inputs[1] = cv2.warpAffine(inputs[1], trans_mat, inputs[1].shape[1::-1], flags=cv2.INTER_LINEAR) 93 | return inputs,target 94 | -------------------------------------------------------------------------------- /SceneFlow/datasets/kitti_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from torch.utils.data import Dataset 4 | from PIL import Image 5 | import numpy as np 6 | from datasets.data_io import get_transform, read_all_lines 7 | from . 
import flow_transforms 8 | import torchvision 9 | 10 | 11 | class KITTIDataset(Dataset): 12 | def __init__(self, datapath, list_filename, training): 13 | self.datapath = datapath 14 | self.left_filenames, self.right_filenames, self.disp_filenames = self.load_path(list_filename) 15 | self.training = training 16 | if self.training: 17 | assert self.disp_filenames is not None 18 | 19 | def load_path(self, list_filename): 20 | lines = read_all_lines(list_filename) 21 | splits = [line.split() for line in lines] 22 | left_images = [x[0] for x in splits] 23 | right_images = [x[1] for x in splits] 24 | if len(splits[0]) == 2: # ground truth not available 25 | return left_images, right_images, None 26 | else: 27 | disp_images = [x[2] for x in splits] 28 | return left_images, right_images, disp_images 29 | 30 | def load_image(self, filename): 31 | return Image.open(filename).convert('RGB') 32 | 33 | def load_disp(self, filename): 34 | data = Image.open(filename) 35 | data = np.array(data, dtype=np.float32) / 256. 36 | return data 37 | 38 | def __len__(self): 39 | return len(self.left_filenames) 40 | 41 | def __getitem__(self, index): 42 | left_img = self.load_image(os.path.join(self.datapath, self.left_filenames[index])) 43 | right_img = self.load_image(os.path.join(self.datapath, self.right_filenames[index])) 44 | 45 | if self.disp_filenames: # has disparity ground truth 46 | disparity = self.load_disp(os.path.join(self.datapath, self.disp_filenames[index])) 47 | else: 48 | disparity = None 49 | 50 | if self.training: 51 | th, tw = 256, 512 52 | #th, tw = 320, 1216 53 | #th, tw = 320, 704 54 | random_brightness = np.random.uniform(0.5, 2.0, 2) 55 | random_gamma = np.random.uniform(0.8, 1.2, 2) 56 | random_contrast = np.random.uniform(0.8, 1.2, 2) 57 | left_img = torchvision.transforms.functional.adjust_brightness(left_img, random_brightness[0]) 58 | left_img = torchvision.transforms.functional.adjust_gamma(left_img, random_gamma[0]) 59 | left_img = torchvision.transforms.functional.adjust_contrast(left_img, random_contrast[0]) 60 | right_img = torchvision.transforms.functional.adjust_brightness(right_img, random_brightness[1]) 61 | right_img = torchvision.transforms.functional.adjust_gamma(right_img, random_gamma[1]) 62 | right_img = torchvision.transforms.functional.adjust_contrast(right_img, random_contrast[1]) 63 | right_img = np.asarray(right_img) 64 | left_img = np.asarray(left_img) 65 | 66 | # w, h = left_img.size 67 | # th, tw = 256, 512 68 | # 69 | # x1 = random.randint(0, w - tw) 70 | # y1 = random.randint(0, h - th) 71 | # 72 | # left_img = left_img.crop((x1, y1, x1 + tw, y1 + th)) 73 | # right_img = right_img.crop((x1, y1, x1 + tw, y1 + th)) 74 | # dataL = dataL[y1:y1 + th, x1:x1 + tw] 75 | # right_img = np.asarray(right_img) 76 | # left_img = np.asarray(left_img) 77 | 78 | # geometric unsymmetric-augmentation 79 | angle = 0 80 | px = 0 81 | if np.random.binomial(1, 0.5): 82 | # angle = 0.1; 83 | # px = 2 84 | angle = 0.05 85 | px = 1 86 | co_transform = flow_transforms.Compose([ 87 | flow_transforms.RandomVdisp(angle, px), 88 | #flow_transforms.Scale(np.random.uniform(self.rand_scale[0], self.rand_scale[1]), order=self.order), 89 | flow_transforms.RandomCrop((th, tw)), 90 | ]) 91 | augmented, disparity = co_transform([left_img, right_img], disparity) 92 | left_img = augmented[0] 93 | right_img = augmented[1] 94 | 95 | right_img.flags.writeable = True 96 | if np.random.binomial(1,0.2): 97 | sx = int(np.random.uniform(35,100)) 98 | sy = int(np.random.uniform(25,75)) 99 | cx = 
int(np.random.uniform(sx,right_img.shape[0]-sx)) 100 | cy = int(np.random.uniform(sy,right_img.shape[1]-sy)) 101 | right_img[cx-sx:cx+sx,cy-sy:cy+sy] = np.mean(np.mean(right_img,0),0)[np.newaxis,np.newaxis] 102 | 103 | # to tensor, normalize 104 | disparity = np.ascontiguousarray(disparity, dtype=np.float32) 105 | processed = get_transform() 106 | left_img = processed(left_img) 107 | right_img = processed(right_img) 108 | 109 | return {"left": left_img, 110 | "right": right_img, 111 | "disparity": disparity} 112 | else: 113 | w, h = left_img.size 114 | 115 | # normalize 116 | processed = get_transform() 117 | left_img = processed(left_img).numpy() 118 | right_img = processed(right_img).numpy() 119 | 120 | # pad to size 1248x384 121 | top_pad = 384 - h 122 | right_pad = 1248 - w 123 | assert top_pad > 0 and right_pad > 0 124 | # pad images 125 | left_img = np.lib.pad(left_img, ((0, 0), (top_pad, 0), (0, right_pad)), mode='constant', constant_values=0) 126 | right_img = np.lib.pad(right_img, ((0, 0), (top_pad, 0), (0, right_pad)), mode='constant', 127 | constant_values=0) 128 | # pad disparity gt 129 | if disparity is not None: 130 | assert len(disparity.shape) == 2 131 | disparity = np.lib.pad(disparity, ((top_pad, 0), (0, right_pad)), mode='constant', constant_values=0) 132 | 133 | if disparity is not None: 134 | return {"left": left_img, 135 | "right": right_img, 136 | "disparity": disparity, 137 | "top_pad": top_pad, 138 | "right_pad": right_pad, 139 | "left_filename": self.left_filenames[index]} 140 | else: 141 | return {"left": left_img, 142 | "right": right_img, 143 | "top_pad": top_pad, 144 | "right_pad": right_pad, 145 | "left_filename": self.left_filenames[index], 146 | "right_filename": self.right_filenames[index]} 147 | -------------------------------------------------------------------------------- /SceneFlow/datasets/kitti_dataset_1215.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from torch.utils.data import Dataset 4 | from PIL import Image 5 | import numpy as np 6 | import cv2 7 | from datasets.data_io import get_transform, read_all_lines, pfm_imread 8 | import torchvision.transforms as transforms 9 | import torch 10 | import matplotlib.pyplot as plt 11 | 12 | 13 | class KITTIDataset1215(Dataset): 14 | def __init__(self, kitti15_datapath, kitti12_datapath, list_filename, training): 15 | self.datapath_15 = kitti15_datapath 16 | self.datapath_12 = kitti12_datapath 17 | self.left_filenames, self.right_filenames, self.disp_filenames, self.pesu = self.load_path(list_filename) 18 | self.training = training 19 | if self.training: 20 | assert self.disp_filenames is not None 21 | 22 | def load_path(self, list_filename): 23 | lines = read_all_lines(list_filename) 24 | splits = [line.split() for line in lines] 25 | left_images = [x[0] for x in splits] 26 | right_images = [x[1] for x in splits] 27 | if len(splits[0]) == 2: # ground truth not available 28 | return left_images, right_images, None 29 | else: 30 | disp_images = [x[2] for x in splits] 31 | if "image" in left_images[0]: 32 | pesu_images = [x.replace('disp_occ_0', 'disp_occ_0_pseudo_gt') for x in disp_images] 33 | else: 34 | pesu_images = [x.replace('disp_occ', 'disp_occ_pseudo_gt') for x in disp_images] 35 | return left_images, right_images, disp_images, pesu_images 36 | 37 | def load_image(self, filename): 38 | return Image.open(filename).convert('RGB') 39 | 40 | def load_disp(self, filename): 41 | data = Image.open(filename) 42 | data = 
np.array(data, dtype=np.float32) / 256. 43 | return data 44 | 45 | def __len__(self): 46 | return len(self.left_filenames) 47 | 48 | def __getitem__(self, index): 49 | 50 | left_name = self.left_filenames[index].split('/')[1] 51 | if left_name.startswith('image'): 52 | self.datapath = self.datapath_15 53 | else: 54 | self.datapath = self.datapath_12 55 | 56 | left_img = self.load_image(os.path.join(self.datapath, self.left_filenames[index])) 57 | right_img = self.load_image(os.path.join(self.datapath, self.right_filenames[index])) 58 | 59 | if self.disp_filenames: # has disparity ground truth 60 | disparity = self.load_disp(os.path.join(self.datapath, self.disp_filenames[index])) 61 | pesu = self.load_disp(os.path.join(self.datapath, self.pesu[index])) 62 | else: 63 | disparity = None 64 | 65 | if self.training: 66 | w, h = left_img.size 67 | crop_w, crop_h = 512, 256 68 | 69 | x1 = random.randint(0, w - crop_w) 70 | if random.randint(0, 10) >= int(8): 71 | y1 = random.randint(0, h - crop_h) 72 | else: 73 | y1 = random.randint(int(0.3 * h), h - crop_h) 74 | 75 | # random crop 76 | left_img = left_img.crop((x1, y1, x1 + crop_w, y1 + crop_h)) 77 | right_img = right_img.crop((x1, y1, x1 + crop_w, y1 + crop_h)) 78 | disparity = disparity[y1:y1 + crop_h, x1:x1 + crop_w] 79 | pesu = pesu[y1:y1 + crop_h, x1:x1 + crop_w] 80 | 81 | # to tensor, normalize 82 | processed = get_transform() 83 | left_img = processed(left_img) 84 | right_img = processed(right_img) 85 | 86 | return {"left": left_img, 87 | "right": right_img, 88 | "disparity": disparity, 89 | "disp_pesu": pesu} 90 | 91 | else: 92 | w, h = left_img.size 93 | 94 | # normalize 95 | processed = get_transform() 96 | left_img = processed(left_img).numpy() 97 | right_img = processed(right_img).numpy() 98 | 99 | # pad to size 1248x384 100 | top_pad = 384 - h 101 | right_pad = 1248 - w 102 | assert top_pad > 0 and right_pad > 0 103 | # pad images 104 | left_img = np.lib.pad(left_img, ((0, 0), (top_pad, 0), (0, right_pad)), mode='constant', constant_values=0) 105 | right_img = np.lib.pad(right_img, ((0, 0), (top_pad, 0), (0, right_pad)), mode='constant', 106 | constant_values=0) 107 | # pad disparity gt 108 | if disparity is not None: 109 | assert len(disparity.shape) == 2 110 | disparity = np.lib.pad(disparity, ((top_pad, 0), (0, right_pad)), mode='constant', constant_values=0) 111 | 112 | 113 | if disparity is not None: 114 | return {"left": left_img, 115 | "right": right_img, 116 | "disparity": disparity, 117 | "left_filename": self.left_filenames[index], 118 | "top_pad": top_pad, 119 | "right_pad": right_pad} 120 | else: 121 | return {"left": left_img, 122 | "right": right_img, 123 | "top_pad": top_pad, 124 | "right_pad": right_pad, 125 | "left_filename": self.left_filenames[index], 126 | "right_filename": self.right_filenames[index]} 127 | 128 | -------------------------------------------------------------------------------- /SceneFlow/datasets/sceneflow_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from torch.utils.data import Dataset 4 | from PIL import Image 5 | import numpy as np 6 | from datasets.data_io import get_transform, read_all_lines, pfm_imread 7 | 8 | 9 | class SceneFlowDatset(Dataset): 10 | def __init__(self, datapath, list_filename, training): 11 | self.datapath = datapath 12 | self.left_filenames, self.right_filenames, self.disp_filenames = self.load_path(list_filename) 13 | self.training = training 14 | 15 | def load_path(self, list_filename): 
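# Each line of the list file holds space-separated relative paths:
# left image, right image, disparity map (PFM).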
16 | lines = read_all_lines(list_filename) 17 | splits = [line.split() for line in lines] 18 | left_images = [x[0] for x in splits] 19 | right_images = [x[1] for x in splits] 20 | disp_images = [x[2] for x in splits] 21 | return left_images, right_images, disp_images 22 | 23 | def load_image(self, filename): 24 | return Image.open(filename).convert('RGB') 25 | 26 | def load_disp(self, filename): 27 | data, scale = pfm_imread(filename) 28 | data = np.ascontiguousarray(data, dtype=np.float32) 29 | return data 30 | 31 | def __len__(self): 32 | return len(self.left_filenames) 33 | 34 | def __getitem__(self, index): 35 | left_img = self.load_image(os.path.join(self.datapath, self.left_filenames[index])) 36 | right_img = self.load_image(os.path.join(self.datapath, self.right_filenames[index])) 37 | disparity = self.load_disp(os.path.join(self.datapath, self.disp_filenames[index])) 38 | 39 | if self.training: 40 | w, h = left_img.size 41 | crop_w, crop_h = 512, 256 42 | 43 | x1 = random.randint(0, w - crop_w) 44 | y1 = random.randint(0, h - crop_h) 45 | 46 | # random crop 47 | left_img = left_img.crop((x1, y1, x1 + crop_w, y1 + crop_h)) 48 | right_img = right_img.crop((x1, y1, x1 + crop_w, y1 + crop_h)) 49 | disparity = disparity[y1:y1 + crop_h, x1:x1 + crop_w] 50 | 51 | # to tensor, normalize 52 | processed = get_transform() 53 | left_img = processed(left_img) 54 | right_img = processed(right_img) 55 | 56 | return {"left": left_img, 57 | "right": right_img, 58 | "disparity": disparity} 59 | else: 60 | w, h = left_img.size 61 | crop_w, crop_h = 960, 512 62 | 63 | left_img = left_img.crop((w - crop_w, h - crop_h, w, h)) 64 | right_img = right_img.crop((w - crop_w, h - crop_h, w, h)) 65 | disparity = disparity[h - crop_h:h, w - crop_w: w] 66 | 67 | processed = get_transform() 68 | left_img = processed(left_img) 69 | right_img = processed(right_img) 70 | 71 | return {"left": left_img, 72 | "right": right_img, 73 | "disparity": disparity, 74 | "top_pad": 0, 75 | "right_pad": 0, 76 | "left_filename": self.left_filenames[index]} 77 | -------------------------------------------------------------------------------- /SceneFlow/main.py: -------------------------------------------------------------------------------- 1 | # from __future__ import print_function, division 2 | import argparse 3 | import os 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.parallel 7 | import torch.backends.cudnn as cudnn 8 | import torch.optim as optim 9 | import torch.utils.data 10 | from torch.autograd import Variable 11 | import torchvision.utils as vutils 12 | import torch.nn.functional as F 13 | import numpy as np 14 | import time 15 | # from tensorboardX import SummaryWriter 16 | from datasets import __datasets__ 17 | from models import __models__, model_loss_train_attn_only, model_loss_train_freeze_attn, model_loss_train, model_loss_test 18 | from utils import * 19 | from torch.utils.data import DataLoader 20 | import gc 21 | # from apex import amp 22 | import cv2 23 | 24 | cudnn.benchmark = True 25 | os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3,4,5' 26 | 27 | parser = argparse.ArgumentParser(description='Attention Concatenation Volume for Accurate and Efficient Stereo Matching (ACVNet)') 28 | parser.add_argument('--model', default='acvnet_ddim', help='select a model structure', choices=__models__.keys()) 29 | parser.add_argument('--maxdisp', type=int, default=192, help='maximum disparity') 30 | parser.add_argument('--dataset', default='sceneflow', help='dataset name', choices=__datasets__.keys()) 31 | 
parser.add_argument('--datapath', default="/mnt/Datasets/Sceneflow/", help='data path')
32 | parser.add_argument('--trainlist', default='./filenames/sceneflow_train.txt', help='training list')
33 | parser.add_argument('--testlist',default='./filenames/sceneflow_test.txt', help='testing list')
34 | parser.add_argument('--lr', type=float, default=0.001, help='base learning rate')
35 | parser.add_argument('--batch_size', type=int, default=23, help='training batch size')
36 | parser.add_argument('--test_batch_size', type=int, default=16, help='testing batch size')
37 | parser.add_argument('--epochs', type=int, default=50, help='number of epochs to train')
38 | parser.add_argument('--lrepochs',default="16,24,32,40,48:2", type=str, help='the epochs to decay lr: the downscale rate')
39 | parser.add_argument('--attention_weights_only', default=False, type=str, help='only train attention weights')
40 | parser.add_argument('--freeze_attention_weights', default=False, type=str, help='freeze attention weights parameters')
41 | parser.add_argument('--logdir',default='./checkpoints/', help='the directory to save logs and checkpoints')
42 | parser.add_argument('--loadckpt', default='./pretrained_model/sceneflow.ckpt',help='load the weights from a specific checkpoint')
43 | parser.add_argument('--resume', action='store_true', help='continue training the model')
44 | parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
45 | parser.add_argument('--summary_freq', type=int, default=20, help='the frequency of saving summary')
46 | parser.add_argument('--save_freq', type=int, default=1, help='the frequency of saving checkpoint')
47 | 
48 | # parse arguments, set seeds
49 | args = parser.parse_args()
50 | torch.manual_seed(args.seed)
51 | torch.cuda.manual_seed(args.seed)
52 | os.makedirs(args.logdir, exist_ok=True)
53 | 
54 | # create summary logger
55 | print("creating new summary file")
56 | # logger = SummaryWriter(args.logdir)
57 | 
58 | # dataset, dataloader
59 | StereoDataset = __datasets__[args.dataset]
60 | train_dataset = StereoDataset(args.datapath, args.trainlist, True)
61 | test_dataset = StereoDataset(args.datapath, args.testlist, False)
62 | TrainImgLoader = DataLoader(train_dataset, args.batch_size, shuffle=True, num_workers=16, drop_last=True)
63 | TestImgLoader = DataLoader(test_dataset, args.test_batch_size, shuffle=False, num_workers=16, drop_last=False)
64 | 
65 | # model, optimizer
66 | model = __models__[args.model](args.maxdisp, args.attention_weights_only, args.freeze_attention_weights)
67 | model = nn.DataParallel(model)
68 | model.cuda()
69 | optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999))
70 | 
71 | # load parameters
72 | start_epoch = 0
73 | if args.resume:
74 |     # find all checkpoint files and sort them by epoch id
75 |     all_saved_ckpts = [fn for fn in os.listdir(args.logdir) if fn.endswith(".ckpt")]
76 |     all_saved_ckpts = sorted(all_saved_ckpts, key=lambda x: int(x.split('_')[-1].split('.')[0]))
77 |     # use the latest checkpoint file
78 |     loadckpt = os.path.join(args.logdir, all_saved_ckpts[-1])
79 |     print("loading the latest model in logdir: {}".format(loadckpt))
80 |     state_dict = torch.load(loadckpt)
81 |     model.load_state_dict(state_dict['model'])
82 |     optimizer.load_state_dict(state_dict['optimizer'])
83 |     start_epoch = state_dict['epoch'] + 1
84 | elif args.loadckpt:
85 |     # load the checkpoint file specified by args.loadckpt
86 |     print("loading model {}".format(args.loadckpt))
87 |     state_dict =
torch.load(args.loadckpt) 88 | model_dict = model.state_dict() 89 | pre_dict = {k: v for k, v in state_dict['model'].items() if k in model_dict} 90 | model_dict.update(pre_dict) 91 | model.load_state_dict(model_dict) 92 | 93 | print("start at epoch {}".format(start_epoch)) 94 | 95 | 96 | def train(): 97 | for epoch_idx in range(start_epoch, args.epochs): 98 | adjust_learning_rate(optimizer, epoch_idx, args.lr, args.lrepochs) 99 | all_loss = 0 100 | # training 101 | for batch_idx, sample in enumerate(TrainImgLoader): 102 | global_step = len(TrainImgLoader) * epoch_idx + batch_idx 103 | start_time = time.time() 104 | do_summary = global_step % args.summary_freq == 0 105 | loss, scalar_outputs, image_outputs = train_sample(sample, compute_metrics=False) 106 | all_loss += loss 107 | # if do_summary: 108 | # save_scalars(logger, 'train', scalar_outputs, global_step) 109 | # save_images(logger, 'train', image_outputs, global_step) 110 | del scalar_outputs, image_outputs 111 | print('Epoch {}/{}, Iter {}/{}, train loss = {:.3f}, time = {:.3f}'.format(epoch_idx, args.epochs, 112 | batch_idx, 113 | len(TrainImgLoader), loss, 114 | time.time() - start_time)) 115 | print('Epoch {}/{}, avg train loss = {:.3f}'.format(epoch_idx, args.epochs, all_loss / len(TrainImgLoader))) 116 | # saving checkpoints 117 | 118 | if (epoch_idx + 1) % args.save_freq == 0: 119 | checkpoint_data = {'epoch': epoch_idx, 'model': model.state_dict(), 'optimizer': optimizer.state_dict()} 120 | #id_epoch = (epoch_idx + 1) % 100 121 | torch.save(checkpoint_data, "{}/checkpoint_{:0>6}.ckpt".format(args.logdir, epoch_idx)) 122 | gc.collect() 123 | 124 | 125 | # train one sample 126 | def train_sample(sample, compute_metrics=False): 127 | model.train() 128 | imgL, imgR, disp_gt = sample['left'], sample['right'], sample['disparity'] 129 | imgL = imgL.cuda() 130 | imgR = imgR.cuda() 131 | disp_gt = disp_gt.cuda() 132 | disp_net = torch.clamp(disp_gt, 0, args.maxdisp-1).unsqueeze(1) 133 | b, c, h, w = disp_net.shape 134 | disp_net = F.interpolate(disp_net, size=(h//4, w//4), mode='bilinear') / 4 135 | optimizer.zero_grad() 136 | disp_ests = model(imgL, imgR, None, disp_net, None) 137 | mask = (disp_gt < args.maxdisp) & (disp_gt > 0) 138 | if args.attention_weights_only: 139 | loss = model_loss_train_attn_only(disp_ests, disp_gt, mask) 140 | elif args.freeze_attention_weights: 141 | loss = model_loss_train_freeze_attn(disp_ests, disp_gt, mask) 142 | else: 143 | loss = model_loss_train(disp_ests, disp_gt, mask) 144 | scalar_outputs = {"loss": loss} 145 | image_outputs = {"disp_est": disp_ests, "disp_gt": disp_gt, "imgL": imgL, "imgR": imgR} 146 | if compute_metrics: 147 | with torch.no_grad(): 148 | image_outputs["errormap"] = [disp_error_image_func.apply(disp_est, disp_gt) for disp_est in disp_ests] 149 | scalar_outputs["EPE"] = [EPE_metric(disp_est, disp_gt, mask) for disp_est in disp_ests] 150 | scalar_outputs["D1"] = [D1_metric(disp_est, disp_gt, mask) for disp_est in disp_ests] 151 | scalar_outputs["Thres1"] = [Thres_metric(disp_est, disp_gt, mask, 1.0) for disp_est in disp_ests] 152 | scalar_outputs["Thres2"] = [Thres_metric(disp_est, disp_gt, mask, 2.0) for disp_est in disp_ests] 153 | scalar_outputs["Thres3"] = [Thres_metric(disp_est, disp_gt, mask, 3.0) for disp_est in disp_ests] 154 | loss.backward() 155 | optimizer.step() 156 | return tensor2float(loss), tensor2float(scalar_outputs), image_outputs 157 | 158 | if __name__ == '__main__': 159 | train() 160 | --------------------------------------------------------------------------------
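A note on the conditioning input that train_sample above builds: the ground-truth disparity is clamped to [0, maxdisp - 1], bilinearly downsampled by 4x, and divided by 4, because a horizontal shift of d pixels at full resolution corresponds to a shift of d/4 pixels on the quarter-resolution grid the cost volume lives on. A minimal standalone sketch of that preparation (the tensor shapes here are illustrative, not taken from the repository):

import torch
import torch.nn.functional as F

maxdisp = 192
disp_gt = torch.rand(2, 256, 512) * maxdisp                   # fake B x H x W ground-truth disparity
disp_net = torch.clamp(disp_gt, 0, maxdisp - 1).unsqueeze(1)  # B x 1 x H x W
b, c, h, w = disp_net.shape
# halving the resolution twice also rescales disparity values by 1/4
disp_net = F.interpolate(disp_net, size=(h // 4, w // 4), mode='bilinear', align_corners=False) / 4
assert disp_net.shape == (2, 1, 64, 128)
assert float(disp_net.max()) <= (maxdisp - 1) / 4             # bilinear averaging cannot exceed the max

Without the division by 4, the conditioning signal would claim shifts four times larger than the downsampled geometry supports.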
/SceneFlow/models/__init__.py: -------------------------------------------------------------------------------- 1 | from models.acv import ACVNet 2 | from models.acv_ddim import ACVNet_DDIM 3 | from models.loss import model_loss_train_attn_only, model_loss_train_freeze_attn, model_loss_train, model_loss_test 4 | 5 | __models__ = { 6 | "acvnet": ACVNet, 7 | "acvnet_ddim": ACVNet_DDIM, 8 | } 9 | --------------------------------------------------------------------------------
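head.py below supplies the timestep conditioning for the acvnet_ddim variant registered above. Its SinusoidalPositionEmbeddings is the standard transformer-style encoding; the following self-contained re-implementation (for illustration only, with an assumed embedding width of 64) shows the mapping and the shapes involved:

import math
import torch

def sinusoidal_embedding(time, dim):
    # mirrors SinusoidalPositionEmbeddings.forward in head.py below
    half_dim = dim // 2
    freqs = torch.exp(torch.arange(half_dim) * -(math.log(10000) / (half_dim - 1)))
    angles = time[:, None].float() * freqs[None, :]
    return torch.cat((angles.sin(), angles.cos()), dim=-1)

t = torch.tensor([0, 10, 999])         # a batch of diffusion timesteps
emb = sinusoidal_embedding(t, dim=64)  # -> (3, 64)
assert emb.shape == (3, 64)

DynamicHead then pushes this embedding through the two-layer time_mlp, projects it back to d_model channels with block_time_mlp, and adds the result to the noisy volume as a per-channel shift.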
/SceneFlow/models/head.py: -------------------------------------------------------------------------------- 1 | """ 2 | DiffusionDet Transformer class. 3 | 4 | Copy-paste from torch.nn.Transformer with modifications: 5 | * positional encodings are passed in MHattention 6 | * extra LN at the end of encoder is removed 7 | * decoder returns a stack of activations from all decoding layers 8 | """ 9 | import copy 10 | import math 11 | 12 | import numpy as np 13 | import torch 14 | from torch import nn, Tensor 15 | import torch.nn.functional as F 16 | 17 | 18 | 19 | _DEFAULT_SCALE_CLAMP = math.log(100000.0 / 16) 20 | 21 | 22 | class SinusoidalPositionEmbeddings(nn.Module): 23 | def __init__(self, dim): 24 | super().__init__() 25 | self.dim = dim 26 | 27 | def forward(self, time): 28 | device = time.device 29 | half_dim = self.dim // 2 30 | embeddings = math.log(10000) / (half_dim - 1) 31 | embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings) 32 | embeddings = time[:, None] * embeddings[None, :] 33 | embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1) 34 | return embeddings 35 | 36 | 37 | class GaussianFourierProjection(nn.Module): 38 | """Gaussian random features for encoding time steps.""" 39 | 40 | def __init__(self, embed_dim, scale=30.): 41 | super().__init__() 42 | # Randomly sample weights during initialization. These weights are fixed 43 | # during optimization and are not trainable. 44 | self.W = nn.Parameter(torch.randn(embed_dim // 2) * scale, requires_grad=False) 45 | 46 | def forward(self, x): 47 | x_proj = x[:, None] * self.W[None, :] * 2 * np.pi 48 | return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) 49 | 50 | 51 | class DynamicHead(nn.Module): 52 | 53 | def __init__(self, d_model): 54 | super().__init__() 55 | self.d_model = d_model 56 | time_dim = d_model * 4 57 | self.time_mlp = nn.Sequential( 58 | SinusoidalPositionEmbeddings(d_model), 59 | nn.Linear(d_model, time_dim), 60 | nn.GELU(), 61 | nn.Linear(time_dim, time_dim), 62 | ) 63 | self.block_time_mlp = nn.Sequential(nn.SiLU(), nn.Linear(d_model * 4, d_model)) 64 | #self.block_time_mlp = nn.Sequential(nn.SiLU(), nn.Linear(d_model * 4, d_model), nn.Sigmoid()) 65 | 66 | self._reset_parameters() 67 | 68 | def _reset_parameters(self): 69 | # init all parameters.
70 | for p in self.parameters(): 71 | if p.dim() > 1: 72 | nn.init.xavier_uniform_(p) 73 | 74 | def forward(self, noisy, t): 75 | time_emb = self.time_mlp(t) 76 | scale_shift = self.block_time_mlp(time_emb).unsqueeze(-1).unsqueeze(-1) 77 | noisy = noisy + scale_shift 78 | #noisy = noisy * scale_shift 79 | # scale, shift = scale_shift.chunk(2, dim=1) 80 | # volume = volume * (scale + 1) + shift 81 | 82 | return noisy -------------------------------------------------------------------------------- /SceneFlow/models/loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import torch 3 | 4 | 5 | def model_loss_train_attn_only(disp_ests, disp_gt, mask): 6 | weights = [1.0] 7 | all_losses = [] 8 | for disp_est, weight in zip(disp_ests, weights): 9 | all_losses.append(weight * F.smooth_l1_loss(disp_est[mask], disp_gt[mask], reduction='mean')) 10 | return sum(all_losses) 11 | 12 | def model_loss_train_freeze_attn(disp_ests, disp_gt, mask): 13 | weights = [0.5, 0.7, 1.0] 14 | all_losses = [] 15 | for disp_est, weight in zip(disp_ests, weights): 16 | all_losses.append(weight * F.smooth_l1_loss(disp_est[mask], disp_gt[mask], reduction='mean')) 17 | return sum(all_losses) 18 | 19 | def model_loss_train(disp_ests, disp_gt, mask): 20 | weights = [0.5, 0.5, 0.7, 1.0] 21 | all_losses = [] 22 | for disp_est, weight in zip(disp_ests, weights): 23 | all_losses.append(weight * F.smooth_l1_loss(disp_est[mask], disp_gt[mask], reduction='mean')) 24 | return sum(all_losses) 25 | 26 | def model_loss_test(disp_ests, disp_gt, mask): 27 | weights = [1.0] 28 | all_losses = [] 29 | for disp_est, weight in zip(disp_ests, weights): 30 | all_losses.append(weight * F.l1_loss(disp_est[mask], disp_gt[mask], reduction='mean')) 31 | return sum(all_losses) 32 | -------------------------------------------------------------------------------- /SceneFlow/save_disp_sceneflow.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import argparse 3 | import os 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.parallel 7 | import torch.backends.cudnn as cudnn 8 | import torch.optim as optim 9 | import torch.utils.data 10 | from torch.autograd import Variable 11 | import torchvision.utils as vutils 12 | import torch.nn.functional as F 13 | import numpy as np 14 | import time 15 | # from tensorboardX import SummaryWriter 16 | from datasets import __datasets__ 17 | from models import __models__ 18 | from utils import * 19 | from torch.utils.data import DataLoader 20 | import gc 21 | import matplotlib.pyplot as plt 22 | import skimage 23 | import skimage.io 24 | import cv2 25 | 26 | # cudnn.benchmark = True 27 | 28 | os.environ['CUDA_VISIBLE_DEVICES'] = '4' 29 | 30 | parser = argparse.ArgumentParser( 31 | description='Attention Concatenation Volume for Accurate and Efficient Stereo Matching (ACVNet)') 32 | parser.add_argument('--model', default='acvnet_ddim', help='select a model structure', choices=__models__.keys()) 33 | parser.add_argument('--maxdisp', type=int, default=192, help='maximum disparity') 34 | parser.add_argument('--dataset', default='sceneflow', help='dataset name', choices=__datasets__.keys()) 35 | parser.add_argument('--datapath', default="/home/zhengdian/dataset/Sceneflow/", help='data path') 36 | parser.add_argument('--testlist', default='./filenames/test_temp.txt', help='testing list') 37 | parser.add_argument('--loadckpt', 
default='/home/zhengdian/code/ACVNet-main/checkpoints/checkpoint_000046.ckpt') 38 | # parse arguments 39 | args = parser.parse_args() 40 | 41 | # dataset, dataloader 42 | StereoDataset = __datasets__[args.dataset] 43 | test_dataset = StereoDataset(args.datapath, args.testlist, False) 44 | TestImgLoader = DataLoader(test_dataset, 1, shuffle=False, num_workers=4, drop_last=False) 45 | 46 | # model, optimizer 47 | model = __models__[args.model](args.maxdisp, False, False) 48 | model = nn.DataParallel(model) 49 | model.cuda() 50 | 51 | model_origin = __models__['acvnet'](args.maxdisp, False, False) 52 | model_origin = nn.DataParallel(model_origin) 53 | model_origin.cuda() 54 | 55 | # load parameters 56 | print("loading model {}".format(args.loadckpt)) 57 | state_dict = torch.load(args.loadckpt) 58 | model.load_state_dict(state_dict['model']) 59 | 60 | state_dict = torch.load('/home/zhengdian/code/ACVNet-main/pretrained_model/sceneflow.ckpt') 61 | model_origin.load_state_dict(state_dict['model']) 62 | 63 | save_dir = '/home/zhengdian/code/ACVNet-main/temp_c/' 64 | 65 | 66 | def test(): 67 | os.makedirs(save_dir, exist_ok=True) 68 | for batch_idx, sample in enumerate(TestImgLoader): 69 | torch.cuda.synchronize() 70 | start_time = time.time() 71 | disp_est_np = tensor2numpy(test_sample(sample)) 72 | torch.cuda.synchronize() 73 | print('Iter {}/{}, time = {:.3f}'.format(batch_idx, len(TestImgLoader), 74 | time.time() - start_time)) 75 | left_filenames = sample["left_filename"] 76 | top_pad_np = tensor2numpy(sample["top_pad"]) 77 | right_pad_np = tensor2numpy(sample["right_pad"]) 78 | 79 | for disp_est, top_pad, right_pad, fn in zip(disp_est_np, top_pad_np, right_pad_np, left_filenames): 80 | assert len(disp_est.shape) == 2 81 | #disp_est = np.array(disp_est[top_pad:, :-right_pad], dtype=np.float32) 82 | disp_est = np.array(disp_est, dtype=np.float32) 83 | fil = os.path.join(save_dir, fn.split('/')[-4]) 84 | fil = os.path.join(fil, fn.split('/')[-3]) 85 | os.makedirs(fil, exist_ok=True) 86 | fil = os.path.join(fil, fn.split('/')[-1]) 87 | print("saving to", fil, disp_est.shape) 88 | disp_est_uint = np.round(disp_est * 255).astype(np.uint16) 89 | #skimage.io.imsave(fil, disp_est_uint) 90 | plt.imsave(fil, disp_est_uint, cmap='jet') 91 | #cv2.imwrite(fn,disp_est_uint, ) 92 | # cv2.imwrite(fn, cv2.applyColorMap(cv2.convertScaleAbs(disp_est_uint, alpha=0.008), cv2.COLORMAP_JET)) 93 | 94 | 95 | # test one sample 96 | @make_nograd_func 97 | def test_sample(sample): 98 | model.eval() 99 | model_origin.eval() 100 | imgL, imgR, disp_gt, filename = sample['left'], sample['right'], sample['disparity'], sample['left_filename'] 101 | imgL = imgL.cuda() 102 | imgR = imgR.cuda() 103 | disp_gt = disp_gt.cuda() 104 | 105 | # disp_ests = model_origin(imgL, imgR) 106 | disp_ = model_origin(imgL, imgR)[-1] 107 | disp_net = torch.clamp(disp_, 0, args.maxdisp - 1).unsqueeze(1) 108 | 109 | b, c, h, w = disp_net.shape 110 | disp_net = F.interpolate(disp_net, size=(h // 4, w // 4), mode='bilinear') / 4 111 | 112 | disp_ests = model(imgL, imgR, disp_, disp_net, None) 113 | return disp_ests[-1] 114 | # return disp_gt 115 | 116 | 117 | if __name__ == '__main__': 118 | test() 119 | -------------------------------------------------------------------------------- /SceneFlow/test_sceneflow_ddim.py: -------------------------------------------------------------------------------- 1 | # from __future__ import print_function, division 2 | import argparse 3 | import os 4 | import torch 5 | import torch.nn as nn 6 | import 
torch.nn.parallel 7 | import torch.backends.cudnn as cudnn 8 | import torch.optim as optim 9 | import torch.utils.data 10 | from torch.autograd import Variable 11 | import torchvision.utils as vutils 12 | import torch.nn.functional as F 13 | import numpy as np 14 | import time 15 | # from tensorboardX import SummaryWriter 16 | from datasets import __datasets__ 17 | from models import __models__, model_loss_train_attn_only, model_loss_train_freeze_attn, model_loss_train, model_loss_test 18 | from utils import * 19 | from models.submodule import * 20 | from datasets.data_io import get_transform, read_all_lines, pfm_imread 21 | from torch.utils.data import DataLoader 22 | from torchvision.utils import save_image 23 | import gc 24 | import matplotlib.pyplot as plt 25 | # from apex import amp 26 | import cv2 27 | from thop import profile 28 | from thop import clever_format 29 | 30 | cudnn.benchmark = True 31 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 32 | 33 | parser = argparse.ArgumentParser(description='Attention Concatenation Volume for Accurate and Efficient Stereo Matching (ACVNet)') 34 | parser.add_argument('--model', default='acvnet_ddim', help='select a model structure', choices=__models__.keys()) 35 | parser.add_argument('--maxdisp', type=int, default=192, help='maximum disparity') 36 | parser.add_argument('--dataset', default='sceneflow', help='dataset name', choices=__datasets__.keys()) 37 | parser.add_argument('--datapath', default="/mnt/Datasets/Sceneflow/", help='data path') 38 | parser.add_argument('--testlist', default='./filenames/sceneflow_test.txt', help='testing list') 39 | parser.add_argument('--test_batch_size', type=int, default=1, help='testing batch size') 40 | parser.add_argument('--loadckpt', default='checkpoints/checkpoint_000046.ckpt') 41 | 42 | # parse arguments, set seeds 43 | args = parser.parse_args() 44 | 45 | # dataset, dataloader 46 | StereoDataset = __datasets__[args.dataset] 47 | test_dataset = StereoDataset(args.datapath, args.testlist, False) 48 | TestImgLoader = DataLoader(test_dataset, args.test_batch_size, shuffle=False, num_workers=4, drop_last=False) 49 | 50 | # model, optimizer 51 | model = __models__[args.model](args.maxdisp, False, False) 52 | total = sum([param.nelement() for param in model.parameters()]) 53 | print("Number of parameters (ours): %.2fM" % (total/1e6)) 54 | model = nn.DataParallel(model) 55 | model.cuda() 56 | 57 | model_origin = __models__['acvnet'](args.maxdisp, False, False) 58 | total = sum([param.nelement() for param in model_origin.parameters()]) 59 | print("Number of parameters (original): %.2fM" % (total/1e6)) 60 | 61 | model_origin = nn.DataParallel(model_origin) 62 | model_origin.cuda() 63 | 64 | # load parameters 65 | print("loading model {}".format(args.loadckpt)) 66 | state_dict = torch.load(args.loadckpt) 67 | model.load_state_dict(state_dict['model']) 68 | 69 | 70 | state_dict = torch.load('pretrained_model/sceneflow.ckpt') 71 | model_origin.load_state_dict(state_dict['model']) 72 | 73 | def test(): 74 | avg_test_scalars = AverageMeterDict() 75 | for batch_idx, sample in enumerate(TestImgLoader): 76 | start_time = time.time() 77 | loss, scalar_outputs = test_sample(sample) 78 | avg_test_scalars.update(scalar_outputs) 79 | del scalar_outputs 80 | print('Iter {}/{}, test loss = {:.3f}, time = {:.3f}'.format(batch_idx, 81 | len(TestImgLoader), loss, 82 | time.time() - start_time)) 83 | 84 | avg_test_scalars = avg_test_scalars.mean() 85 | print("avg_test_scalars", avg_test_scalars) 86 | 87 | 88 | # test one sample 89 | 
@make_nograd_func 90 | def test_sample(sample): 91 | model.eval() 92 | model_origin.eval() 93 | imgL, imgR, disp_gt, filename = sample['left'], sample['right'], sample['disparity'], sample['left_filename'] 94 | imgL = imgL.cuda() 95 | imgR = imgR.cuda() 96 | disp_gt = disp_gt.cuda() 97 | 98 | mask_gt = (disp_gt < args.maxdisp) & (disp_gt > 0) 99 | 100 | #disp_ests = model_origin(imgL, imgR) 101 | disp_ = model_origin(imgL, imgR)[-1] 102 | 103 | disp_net = torch.clamp(disp_, 0, args.maxdisp - 1).unsqueeze(1) 104 | 105 | b, c, h, w = disp_net.shape 106 | disp_net = F.interpolate(disp_net, size=(h // 4, w // 4), mode='bilinear') / 4 107 | 108 | disp_ests = model(imgL, imgR, disp_, disp_net, None) 109 | 110 | disp_gts = [disp_gt] 111 | loss = model_loss_test(disp_ests, disp_gt, mask_gt) 112 | scalar_outputs = {"loss": loss} 113 | scalar_outputs["EPE"] = [EPE_metric(disp_est, disp_gt, mask_gt) for disp_est in disp_ests] 114 | scalar_outputs["D1"] = [D1_metric(disp_est, disp_gt, mask_gt) for disp_est in disp_ests] 115 | scalar_outputs["Thres1"] = [Thres_metric(disp_est, disp_gt, mask_gt, 1.0) for disp_est in disp_ests] 116 | scalar_outputs["Thres2"] = [Thres_metric(disp_est, disp_gt, mask_gt, 2.0) for disp_est in disp_ests] 117 | scalar_outputs["Thres3"] = [Thres_metric(disp_est, disp_gt, mask_gt, 3.0) for disp_est in disp_ests] 118 | 119 | # if scalar_outputs["EPE"][0] > 1: 120 | # print(filename) 121 | # raise 122 | return tensor2float(loss), tensor2float(scalar_outputs) 123 | 124 | if __name__ == '__main__': 125 | test() 126 | -------------------------------------------------------------------------------- /SceneFlow/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from utils.experiment import * 2 | from utils.visualization import * 3 | from utils.metrics import D1_metric, Thres_metric, EPE_metric, EPE_metric_mask, Thres_metric_mask, D1_metric_mask 4 | from utils.misc import init_distributed_mode --------------------------------------------------------------------------------
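experiment.py below implements the training helpers imported here. Its adjust_learning_rate consumes the --lrepochs string from main.py ("16,24,32,40,48:2"): the comma-separated epochs before the colon are decay milestones, and the number after the colon is the divisor, so a base lr of 0.001 drops to 0.0005 at epoch 16, 0.00025 at epoch 24, and so on. A minimal sketch of the same parsing (the helper name decayed_lr is hypothetical, used only for illustration):

def decayed_lr(base_lr, epoch, lrepochs="16,24,32,40,48:2"):
    # same milestone parsing as adjust_learning_rate in experiment.py below
    milestones, rate = lrepochs.split(':')
    lr = base_lr
    for eid in (int(e) for e in milestones.split(',')):
        if epoch >= eid:
            lr /= float(rate)
        else:
            break
    return lr

assert abs(decayed_lr(0.001, 15) - 0.001) < 1e-12
assert abs(decayed_lr(0.001, 16) - 0.0005) < 1e-12
assert abs(decayed_lr(0.001, 24) - 0.00025) < 1e-12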
/SceneFlow/utils/experiment.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.parallel 5 | import torch.utils.data 6 | from torch.autograd import Variable 7 | import torchvision.utils as vutils 8 | import torch.nn.functional as F 9 | import numpy as np 10 | import copy 11 | 12 | 13 | def make_iterative_func(func): 14 | def wrapper(vars): 15 | if isinstance(vars, list): 16 | return [wrapper(x) for x in vars] 17 | elif isinstance(vars, tuple): 18 | return tuple([wrapper(x) for x in vars]) 19 | elif isinstance(vars, dict): 20 | return {k: wrapper(v) for k, v in vars.items()} 21 | else: 22 | return func(vars) 23 | 24 | return wrapper 25 | 26 | 27 | def make_nograd_func(func): 28 | def wrapper(*f_args, **f_kwargs): 29 | with torch.no_grad(): 30 | ret = func(*f_args, **f_kwargs) 31 | return ret 32 | 33 | return wrapper 34 | 35 | 36 | @make_iterative_func 37 | def tensor2float(vars): 38 | if isinstance(vars, float): 39 | return vars 40 | elif isinstance(vars, torch.Tensor): 41 | return vars.data.item() 42 | else: 43 | raise NotImplementedError("invalid input type for tensor2float") 44 | 45 | 46 | @make_iterative_func 47 | def tensor2numpy(vars): 48 | if isinstance(vars, np.ndarray): 49 | return vars 50 | elif isinstance(vars, torch.Tensor): 51 | return vars.data.cpu().numpy() 52 | else: 53 | raise NotImplementedError("invalid input type for tensor2numpy") 54 | 55 | 56 | @make_iterative_func 57 | def check_allfloat(vars): 58 | assert isinstance(vars, float) 59 | 60 | 61 | def save_scalars(logger, mode_tag, scalar_dict, global_step): 62 | scalar_dict = tensor2float(scalar_dict) 63 | for tag, values in scalar_dict.items(): 64 | if not isinstance(values, list) and not isinstance(values, tuple): 65 | values = [values] 66 | for idx, value in enumerate(values): 67 | scalar_name = '{}/{}'.format(mode_tag, tag) 68 | # if len(values) > 1: 69 | scalar_name = scalar_name + "_" + str(idx) 70 | logger.add_scalar(scalar_name, value, global_step) 71 | 72 | 73 | def save_images(logger, mode_tag, images_dict, global_step): 74 | images_dict = tensor2numpy(images_dict) 75 | for tag, values in images_dict.items(): 76 | if not isinstance(values, list) and not isinstance(values, tuple): 77 | values = [values] 78 | for idx, value in enumerate(values): 79 | if len(value.shape) == 3: 80 | value = value[:, np.newaxis, :, :] 81 | value = value[:1] 82 | value = torch.from_numpy(value) 83 | 84 | image_name = '{}/{}'.format(mode_tag, tag) 85 | if len(values) > 1: 86 | image_name = image_name + "_" + str(idx) 87 | logger.add_image(image_name, vutils.make_grid(value, padding=0, nrow=1, normalize=True, scale_each=True), 88 | global_step) 89 | 90 | 91 | def adjust_learning_rate(optimizer, epoch, base_lr, lrepochs): 92 | splits = lrepochs.split(':') 93 | assert len(splits) == 2 94 | 95 | # parse the epochs to downscale the learning rate (before :) 96 | downscale_epochs = [int(eid_str) for eid_str in splits[0].split(',')] 97 | # parse downscale rate (after :) 98 | downscale_rate = float(splits[1]) 99 | 
print("downscale epochs: {}, downscale rate: {}".format(downscale_epochs, downscale_rate)) 100 | 101 | lr = base_lr 102 | for eid in downscale_epochs: 103 | if epoch >= eid: 104 | lr /= downscale_rate 105 | else: 106 | break 107 | print("setting learning rate to {}".format(lr)) 108 | for param_group in optimizer.param_groups: 109 | param_group['lr'] = lr 110 | 111 | 112 | class AverageMeter(object): 113 | def __init__(self): 114 | self.sum_value = 0. 115 | self.count = 0 116 | 117 | def update(self, x): 118 | check_allfloat(x) 119 | self.sum_value += x 120 | self.count += 1 121 | 122 | def mean(self): 123 | return self.sum_value / self.count 124 | 125 | 126 | class AverageMeterDict(object): 127 | def __init__(self): 128 | self.data = None 129 | self.count = 0 130 | 131 | def update(self, x): 132 | check_allfloat(x) 133 | self.count += 1 134 | if self.data is None: 135 | self.data = copy.deepcopy(x) 136 | else: 137 | for k1, v1 in x.items(): 138 | if isinstance(v1, float): 139 | self.data[k1] += v1 140 | elif isinstance(v1, tuple) or isinstance(v1, list): 141 | for idx, v2 in enumerate(v1): 142 | self.data[k1][idx] += v2 143 | else: 144 | assert NotImplementedError("error input type for update AvgMeterDict") 145 | 146 | def mean(self): 147 | @make_iterative_func 148 | def get_mean(v): 149 | return v / float(self.count) 150 | 151 | return get_mean(self.data) 152 | 153 | 154 | import torch.distributed as dist 155 | def get_world_size(): 156 | if not dist.is_available(): 157 | return 1 158 | if not dist.is_initialized(): 159 | return 1 160 | return dist.get_world_size() 161 | 162 | 163 | from collections import defaultdict 164 | def reduce_scalar_outputs(scalar_outputs): 165 | world_size = get_world_size() 166 | if world_size < 2: 167 | return scalar_outputs 168 | with torch.no_grad(): 169 | names = [] 170 | scalars = [] 171 | for k in sorted(scalar_outputs.keys()): 172 | if isinstance(scalar_outputs[k], (list, tuple)): 173 | for sub_var in scalar_outputs[k]: 174 | names.append(k) 175 | scalars.append(sub_var) 176 | else: 177 | names.append(k) 178 | scalars.append(scalar_outputs[k]) 179 | 180 | scalars = torch.stack(scalars, dim=0) 181 | dist.reduce(scalars, dst=0) 182 | if dist.get_rank() == 0: 183 | # only main process gets accumulated, so only divide by 184 | # world_size in this case 185 | scalars /= world_size 186 | 187 | reduced_scalars = defaultdict(list) 188 | for name, scalar in zip(names, scalars): 189 | reduced_scalars[name].append(scalar) 190 | 191 | return dict(reduced_scalars) -------------------------------------------------------------------------------- /SceneFlow/utils/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from utils.experiment import make_nograd_func 4 | from torch.autograd import Variable 5 | from torch import Tensor 6 | 7 | 8 | # Update D1 from >3px to >=3px & >5% 9 | # matlab code: 10 | # E = abs(D_gt - D_est); 11 | # n_err = length(find(D_gt > 0 & E > tau(1) & E. 
/ abs(D_gt) > tau(2))); 12 | # n_total = length(find(D_gt > 0)); 13 | # d_err = n_err / n_total; 14 | 15 | def check_shape_for_metric_computation(*vars): 16 | assert isinstance(vars, tuple) 17 | for var in vars: 18 | assert len(var.size()) == 3 19 | assert var.size() == vars[0].size() 20 | 21 | # a wrapper to compute metrics for each image individually 22 | def compute_metric_for_each_image(metric_func): 23 | def wrapper(D_ests, D_gts, masks, *nargs): 24 | check_shape_for_metric_computation(D_ests, D_gts, masks) 25 | bn = D_gts.shape[0] # batch size 26 | results = [] # a list to store results for each image 27 | # compute result one by one 28 | for idx in range(bn): 29 | # if tensor, then pick idx, else pass the same value 30 | cur_nargs = [x[idx] if isinstance(x, (Tensor, Variable)) else x for x in nargs] 31 | if masks[idx].float().mean() / (D_gts[idx] > 0).float().mean() < 0.1: 32 | print("masks[idx].float().mean() too small, skip") 33 | else: 34 | ret = metric_func(D_ests[idx], D_gts[idx], masks[idx], *cur_nargs) 35 | results.append(ret) 36 | if len(results) == 0: 37 | print("masks[idx].float().mean() too small for all images in this batch, return 0") 38 | return torch.tensor(0, dtype=torch.float32, device=D_gts.device) 39 | else: 40 | return torch.stack(results).mean() 41 | return wrapper 42 | 43 | @make_nograd_func 44 | @compute_metric_for_each_image 45 | def D1_metric(D_est, D_gt, mask): 46 | D_est, D_gt = D_est[mask], D_gt[mask] 47 | E = torch.abs(D_gt - D_est) 48 | err_mask = (E > 3) & (E / D_gt.abs() > 0.05) 49 | return torch.mean(err_mask.float()) 50 | 51 | @make_nograd_func 52 | @compute_metric_for_each_image 53 | def Thres_metric(D_est, D_gt, mask, thres): 54 | assert isinstance(thres, (int, float)) 55 | D_est, D_gt = D_est[mask], D_gt[mask] 56 | E = torch.abs(D_gt - D_est) 57 | err_mask = E > thres 58 | return torch.mean(err_mask.float()) 59 | 60 | # NOTE: please do not use this to build up training loss 61 | @make_nograd_func 62 | @compute_metric_for_each_image 63 | def EPE_metric(D_est, D_gt, mask): 64 | D_est, D_gt = D_est[mask], D_gt[mask] 65 | return F.l1_loss(D_est, D_gt, reduction='mean') 66 | 67 | 68 | 69 | @make_nograd_func 70 | @compute_metric_for_each_image 71 | def D1_metric_mask(D_est, D_gt, mask, mask_img): 72 | # D_est, D_gt = D_est[(mask&mask_img)], D_gt[(mask&mask_img)] 73 | D_est, D_gt = D_est[mask_img], D_gt[mask_img] 74 | E = torch.abs(D_gt - D_est) 75 | err_mask = (E > 3) & (E / D_gt.abs() > 0.05) 76 | return torch.mean(err_mask.float()) 77 | 78 | @make_nograd_func 79 | @compute_metric_for_each_image 80 | def Thres_metric_mask(D_est, D_gt, mask, thres, mask_img): 81 | assert isinstance(thres, (int, float)) 82 | # D_est, D_gt = D_est[(mask&mask_img)], D_gt[(mask&mask_img)] 83 | D_est, D_gt = D_est[mask_img], D_gt[mask_img] 84 | E = torch.abs(D_gt - D_est) 85 | err_mask = E > thres 86 | return torch.mean(err_mask.float()) 87 | 88 | # NOTE: please do not use this to build up training loss 89 | @make_nograd_func 90 | @compute_metric_for_each_image 91 | def EPE_metric_mask(D_est, D_gt, mask, mask_img): 92 | # print((mask&mask_img).size(), D_est.size(), mask, mask_img) 93 | # D_est, D_gt = D_est[(mask&mask_img)], D_gt[(mask&mask_img)] 94 | D_est, D_gt = D_est[mask_img], D_gt[mask_img] 95 | return F.l1_loss(D_est, D_gt, reduction='mean') 96 | 97 | -------------------------------------------------------------------------------- /SceneFlow/utils/misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import 
torch 3 | 4 | 5 | def setup_for_distributed(is_master): 6 | """ 7 | This function disables printing when not in master process 8 | """ 9 | import builtins as __builtin__ 10 | builtin_print = __builtin__.print 11 | 12 | def print(*args, **kwargs): 13 | force = kwargs.pop('force', False) 14 | if is_master or force: 15 | builtin_print(*args, **kwargs) 16 | 17 | __builtin__.print = print 18 | 19 | 20 | def init_distributed_mode(args): 21 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 22 | args.rank = int(os.environ["RANK"]) 23 | args.world_size = int(os.environ['WORLD_SIZE']) 24 | args.gpu = args.local_rank 25 | args.dist_url = 'env://' 26 | os.environ['LOCAL_SIZE'] = str(torch.cuda.device_count()) 27 | else: 28 | print('Not using distributed mode') 29 | args.distributed = False 30 | return 31 | 32 | args.distributed = True 33 | 34 | torch.cuda.set_device(args.gpu) 35 | args.dist_backend = 'nccl' 36 | print('| distributed init (rank {}): {}'.format( 37 | args.rank, args.dist_url), flush=True) 38 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 39 | world_size=args.world_size, rank=args.rank) 40 | torch.distributed.barrier() 41 | setup_for_distributed(args.rank == 0) 42 | -------------------------------------------------------------------------------- /SceneFlow/utils/visualization.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import torch.nn as nn 4 | import torch.utils.data 5 | from torch.autograd import Variable, Function 6 | import torch.nn.functional as F 7 | import math 8 | import numpy as np 9 | 10 | 11 | def gen_error_colormap(): 12 | cols = np.array( 13 | [[0 / 3.0, 0.1875 / 3.0, 49, 54, 149], 14 | [0.1875 / 3.0, 0.375 / 3.0, 69, 117, 180], 15 | [0.375 / 3.0, 0.75 / 3.0, 116, 173, 209], 16 | [0.75 / 3.0, 1.5 / 3.0, 171, 217, 233], 17 | [1.5 / 3.0, 3 / 3.0, 224, 243, 248], 18 | [3 / 3.0, 6 / 3.0, 254, 224, 144], 19 | [6 / 3.0, 12 / 3.0, 253, 174, 97], 20 | [12 / 3.0, 24 / 3.0, 244, 109, 67], 21 | [24 / 3.0, 48 / 3.0, 215, 48, 39], 22 | [48 / 3.0, np.inf, 165, 0, 38]], dtype=np.float32) 23 | cols[:, 2: 5] /= 255. 24 | return cols 25 | 26 | 27 | error_colormap = gen_error_colormap() 28 | 29 | 30 | class disp_error_image_func(Function): 31 | def forward(self, D_est_tensor, D_gt_tensor, abs_thres=3., rel_thres=0.05, dilate_radius=1): 32 | D_gt_np = D_gt_tensor.detach().cpu().numpy() 33 | D_est_np = D_est_tensor.detach().cpu().numpy() 34 | B, H, W = D_gt_np.shape 35 | # valid mask 36 | mask = D_gt_np > 0 37 | # error in percentage. When error <= 1, the pixel is valid since <= 3px & 5% 38 | error = np.abs(D_gt_np - D_est_np) 39 | error[np.logical_not(mask)] = 0 40 | error[mask] = np.minimum(error[mask] / abs_thres, (error[mask] / D_gt_np[mask]) / rel_thres) 41 | # get colormap 42 | cols = error_colormap 43 | # create error image 44 | error_image = np.zeros([B, H, W, 3], dtype=np.float32) 45 | for i in range(cols.shape[0]): 46 | error_image[np.logical_and(error >= cols[i][0], error < cols[i][1])] = cols[i, 2:] 47 | # TODO: imdilate 48 | # error_image = cv2.imdilate(D_err, strel('disk', dilate_radius)); 49 | error_image[np.logical_not(mask)] = 0. 
50 | # show color tag in the top-left corner of the image 51 | for i in range(cols.shape[0]): 52 | distance = 20 53 | error_image[:, :10, i * distance:(i + 1) * distance, :] = cols[i, 2:] 54 | 55 | return torch.from_numpy(np.ascontiguousarray(error_image.transpose([0, 3, 1, 2]))) 56 | 57 | def backward(self, grad_output): 58 | return None 59 | --------------------------------------------------------------------------------
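A worked example of the D1 criterion from metrics.py above: a pixel counts as erroneous only when its disparity error exceeds both 3 px and 5% of the ground-truth value, so large disparities tolerate proportionally larger absolute errors. A toy check with hand-picked values (illustrative only, mirroring the err_mask expression in D1_metric):

import torch

D_gt = torch.tensor([10.0, 100.0, 100.0])
D_est = torch.tensor([14.0, 104.0, 110.0])
E = (D_gt - D_est).abs()                 # errors: 4, 4, 10
err = (E > 3) & (E / D_gt.abs() > 0.05)  # True, False, True
assert err.tolist() == [True, False, True]
print("D1 = {:.3f}".format(err.float().mean().item()))  # 0.667

The second pixel is off by 4 px but only 4% of its 100 px ground truth, so it does not count toward D1; the same 4 px error on a 10 px disparity does.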