├── models ├── __init__.py ├── models.py └── submodules.py ├── utils ├── __init__.py ├── utils.py └── logger.py ├── dataloader ├── __init__.py ├── readpfm.py ├── kitti2015load.py ├── dataloader.py └── sceneflow.py ├── .gitignore ├── reference ├── 1.png ├── 2.png ├── 3.png ├── 4.png ├── raw.png ├── left_test.png ├── right_test.png └── network_structure.png ├── val_set.txt ├── LICENSE ├── paddle_env.yml ├── inference.py ├── README.md ├── train.py └── finetune.py /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dataloader/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | 3 | __pycache__/ 4 | 5 | results/ 6 | 7 | log/ 8 | 9 | dataset/ -------------------------------------------------------------------------------- /reference/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PrinceVictor/LWSNet/HEAD/reference/1.png -------------------------------------------------------------------------------- /reference/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PrinceVictor/LWSNet/HEAD/reference/2.png -------------------------------------------------------------------------------- /reference/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PrinceVictor/LWSNet/HEAD/reference/3.png 
-------------------------------------------------------------------------------- /reference/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PrinceVictor/LWSNet/HEAD/reference/4.png -------------------------------------------------------------------------------- /reference/raw.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PrinceVictor/LWSNet/HEAD/reference/raw.png -------------------------------------------------------------------------------- /reference/left_test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PrinceVictor/LWSNet/HEAD/reference/left_test.png -------------------------------------------------------------------------------- /reference/right_test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PrinceVictor/LWSNet/HEAD/reference/right_test.png -------------------------------------------------------------------------------- /reference/network_structure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PrinceVictor/LWSNet/HEAD/reference/network_structure.png -------------------------------------------------------------------------------- /val_set.txt: -------------------------------------------------------------------------------- 1 | 13 2 | 32 3 | 36 4 | 37 5 | 38 6 | 43 7 | 46 8 | 54 9 | 58 10 | 62 11 | 75 12 | 76 13 | 79 14 | 82 15 | 92 16 | 93 17 | 99 18 | 106 19 | 108 20 | 114 21 | 115 22 | 117 23 | 124 24 | 131 25 | 135 26 | 138 27 | 139 28 | 141 29 | 144 30 | 148 31 | 159 32 | 162 33 | 164 34 | 167 35 | 176 36 | 179 37 | 182 38 | 192 39 | 193 40 | 199 41 | -------------------------------------------------------------------------------- /utils/utils.py: 
-------------------------------------------------------------------------------- 1 | class AverageMeter(object): 2 | """Computes and stores the average and current value""" 3 | 4 | def __init__(self): 5 | self.reset() 6 | 7 | def reset(self): 8 | self.val = 0 9 | self.avg = 0 10 | self.sum = 0 11 | self.count = 0 12 | 13 | def update(self, val, n=1): 14 | self.val = val 15 | self.sum += val * n 16 | self.count += n 17 | self.avg = self.sum / self.count 18 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 PrinceVictor 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /dataloader/readpfm.py: -------------------------------------------------------------------------------- 1 | import re 2 | import numpy as np 3 | import sys 4 | 5 | 6 | def readPFM(file): 7 | file = open(file, 'rb') 8 | 9 | color = None 10 | width = None 11 | height = None 12 | scale = None 13 | endian = None 14 | 15 | header = file.readline().rstrip() 16 | if header == b'PF': 17 | color = True 18 | elif header == b'Pf': 19 | color = False 20 | else: 21 | raise Exception('Not a PFM file.') 22 | 23 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8')) 24 | if dim_match: 25 | width, height = map(int, dim_match.groups()) 26 | else: 27 | raise Exception('Malformed PFM header.') 28 | 29 | scale = float(file.readline().rstrip()) 30 | if scale < 0: # little-endian 31 | endian = '<' 32 | scale = -scale 33 | else: 34 | endian = '>' # big-endian 35 | 36 | data = np.fromfile(file, endian + 'f') 37 | shape = (height, width, 3) if color else (height, width) 38 | 39 | data = np.reshape(data, shape) 40 | data = np.flipud(data) 41 | file.close() 42 | return data, scale 43 | -------------------------------------------------------------------------------- /dataloader/kitti2015load.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | 4 | import numpy as np 5 | 6 | def dataloader(filepath, split_file=None): 7 | left_fold = 'image_2/' 8 | right_fold = 'image_3/' 9 | disp_L = 'disp_occ_0/' 10 | disp_R = 'disp_occ_1/' 11 | 12 | image = [img for img in os.listdir(filepath + left_fold) if img.find('_10') > -1] 13 | 14 | all_index = np.arange(200) 15 | if split_file is None: 16 | np.random.shuffle(all_index) 17 | vallist = all_index[:40] 18 | else: 19 | with open(split_file) as f: 20 | vallist = sorted([int(x.strip()) for x in f.readlines() if len(x) > 0]) 21 | 22 | val = ['{:06d}_10.png'.format(x) for x in vallist] 
23 | train = [x for x in image if x not in val] 24 | 25 | left_train = [os.path.join(filepath, left_fold, img) for img in train] 26 | right_train = [os.path.join(filepath, right_fold, img) for img in train] 27 | disp_train_L = [os.path.join(filepath, disp_L, img) for img in train] 28 | # disp_train_R = [filepath+disp_R+img for img in train] 29 | 30 | left_val = [os.path.join(filepath, left_fold, img) for img in val] 31 | right_val = [os.path.join(filepath, right_fold, img) for img in val] 32 | disp_val_L = [os.path.join(filepath, disp_L, img) for img in val] 33 | # disp_val_R = [filepath+disp_R+img for img in val] 34 | 35 | return left_train, right_train, disp_train_L, left_val, right_val, disp_val_L -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | import time 5 | 6 | 7 | def setup_logger(name, save_path=None): 8 | file_name = time.strftime('-%Y-%m-%d-%H-%M', time.localtime(time.time())) 9 | name = name.replace(".", "_") 10 | if os.path.dirname(name) == "": 11 | file_name = name + file_name + ".log" 12 | else: 13 | file_name = os.path.dirname(name) + file_name + ".log" 14 | 15 | log_formatter = logging.Formatter( 16 | "[%(asctime)s %(filename)s:%(lineno)s] %(levelname)s: %(message)s", 17 | datefmt='%Y-%m-%d %H:%M:%S') 18 | 19 | logger = logging.getLogger(file_name) 20 | stream_handler = logging.StreamHandler(stream=sys.stderr) 21 | stream_handler.setLevel(logging.DEBUG) 22 | stream_handler.setFormatter(log_formatter) 23 | logger.addHandler(stream_handler) 24 | 25 | if file_name in [h.name for h in logger.handlers]: 26 | return 27 | 28 | if save_path is not None: 29 | if os.path.dirname(save_path) != '': 30 | if not os.path.isdir(os.path.dirname(save_path)): 31 | os.makedirs(os.path.dirname(save_path)) 32 | 33 | file_handler = logging.FileHandler(os.path.join(save_path, file_name)) 34 | 
file_handler.set_name(file_name) 35 | file_handler.setLevel(logging.DEBUG) 36 | file_handler.setFormatter(log_formatter) 37 | 38 | logger.addHandler(file_handler) 39 | logger.setLevel(logging.DEBUG) 40 | 41 | return logger 42 | 43 | 44 | -------------------------------------------------------------------------------- /dataloader/dataloader.py: -------------------------------------------------------------------------------- 1 | import paddle 2 | from paddle.io import Dataset 3 | from paddle.vision.transforms import Compose, Normalize, Transpose, ToTensor 4 | import os 5 | from PIL import Image 6 | from . import readpfm as rp 7 | import numpy as np 8 | import random 9 | 10 | imagenet_stats = {'mean': [0.485, 0.456, 0.406], 11 | 'std': [0.229, 0.224, 0.225]} 12 | 13 | IMG_EXTENSIONS = [ 14 | '.jpg', '.JPG', '.jpeg', '.JPEG', 15 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 16 | ] 17 | 18 | def default_loader(path): 19 | return Image.open(path).convert('RGB') 20 | 21 | def sceneflow_dispLoader(path): 22 | return rp.readPFM(path) 23 | 24 | def Kitti_dispLoader(path): 25 | return Image.open(path) 26 | 27 | class MyDataloader(Dataset): 28 | def __init__(self, left, right, left_disparity, training=True, loader=default_loader, kitti_set=True): 29 | super(MyDataloader, self).__init__() 30 | 31 | self.left = left 32 | self.right = right 33 | self.disp_L = left_disparity 34 | self.loader = loader 35 | self.training = training 36 | self.kitti_set = kitti_set 37 | if self.kitti_set: 38 | self.dploader = Kitti_dispLoader 39 | else: 40 | self.dploader = sceneflow_dispLoader 41 | 42 | self.transform = Compose([Transpose(), 43 | Normalize(mean=imagenet_stats["mean"], std=imagenet_stats["std"])]) 44 | # self.transform = ToTensor() 45 | 46 | def __getitem__(self, index): 47 | left = self.left[index] 48 | right = self.right[index] 49 | disp_L = self.disp_L[index] 50 | 51 | left_img = self.loader(left) 52 | right_img = self.loader(right) 53 | 54 | if self.kitti_set: 55 | dataL = 
self.dploader(disp_L) 56 | dataL = np.ascontiguousarray(dataL, dtype=np.float32) / 256 57 | else: 58 | dataL, scaleL = self.dploader(disp_L) 59 | dataL = np.ascontiguousarray(dataL, dtype=np.float32) 60 | 61 | if self.training: 62 | w, h = left_img.size 63 | th, tw = 256, 512 64 | 65 | x1 = random.randint(0, w - tw) 66 | y1 = random.randint(0, h - th) 67 | 68 | left_img = np.array(left_img.crop((x1, y1, x1 + tw, y1 + th)), dtype=np.float32)/255 69 | right_img = np.array(right_img.crop((x1, y1, x1 + tw, y1 + th)), dtype=np.float32)/255 70 | dataL = dataL[y1:y1 + th, x1:x1 + tw] 71 | 72 | left_img = self.transform(left_img) 73 | right_img = self.transform(right_img) 74 | 75 | return left_img, right_img, dataL 76 | 77 | else: 78 | w, h = left_img.size 79 | 80 | if self.kitti_set: 81 | left_img = np.array(left_img.crop((w - 1232, h - 368, w, h)), dtype=np.float32)/255 82 | right_img = np.array(right_img.crop((w - 1232, h - 368, w, h)), dtype=np.float32)/255 83 | dataL = dataL[h - 368:h, w-1232:w] 84 | else: 85 | left_img = np.array(left_img.crop((w - 960, h - 544, w, h)), dtype=np.float32)/255 86 | right_img = np.array(right_img.crop((w - 960, h - 544, w, h)), dtype=np.float32)/255 87 | # dataL = dataL[h - 544:h, w - 960:w] 88 | 89 | left_img = self.transform(left_img) 90 | right_img = self.transform(right_img) 91 | 92 | return left_img, right_img, dataL 93 | 94 | def __len__(self): 95 | return len(self.left) 96 | 97 | if __name__ == "__main__": 98 | 99 | print("this is dataloader.py") -------------------------------------------------------------------------------- /dataloader/sceneflow.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from PIL import Image 4 | from . 
import readpfm as rp 5 | 6 | IMG_EXTENSIONS = [ 7 | '.jpg', '.JPG', '.jpeg', '.JPEG', 8 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 9 | ] 10 | 11 | def is_image_file(filename): 12 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 13 | 14 | 15 | def default_loader(path): 16 | return Image.open(path).convert('RGB') 17 | 18 | 19 | def disparity_loader(path): 20 | return rp.readPFM(path) 21 | 22 | 23 | from PIL import Image 24 | import os 25 | import os.path 26 | 27 | IMG_EXTENSIONS = [ 28 | '.jpg', '.JPG', '.jpeg', '.JPEG', 29 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 30 | ] 31 | 32 | 33 | def is_image_file(filename): 34 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 35 | 36 | 37 | def dataloader(filepath): 38 | filepath += '/' 39 | classes = [d for d in os.listdir(filepath) if os.path.isdir(os.path.join(filepath, d))] 40 | image = [img for img in classes if img.find('frames_cleanpass') > -1] 41 | disp = [dsp for dsp in classes if dsp.find('disparity') > -1] 42 | 43 | monkaa_path = filepath + [x for x in image if 'monkaa' in x][0] 44 | monkaa_disp = filepath + [x for x in disp if 'monkaa' in x][0] 45 | 46 | monkaa_dir = os.listdir(monkaa_path) 47 | 48 | all_left_img = [] 49 | all_right_img = [] 50 | all_left_disp = [] 51 | test_left_img = [] 52 | test_right_img = [] 53 | test_left_disp = [] 54 | 55 | for dd in monkaa_dir: 56 | for im in os.listdir(monkaa_path + '/' + dd + '/left/'): 57 | if is_image_file(monkaa_path + '/' + dd + '/left/' + im): 58 | all_left_img.append(monkaa_path + '/' + dd + '/left/' + im) 59 | all_left_disp.append(monkaa_disp + '/' + dd + '/left/' + im.split(".")[0] + '.pfm') 60 | 61 | for im in os.listdir(monkaa_path + '/' + dd + '/right/'): 62 | if is_image_file(monkaa_path + '/' + dd + '/right/' + im): 63 | all_right_img.append(monkaa_path + '/' + dd + '/right/' + im) 64 | 65 | flying_path = filepath + [x for x in image if x == 'frames_cleanpass'][0] 66 | flying_disp = filepath + 
[x for x in disp if x == 'frames_disparity'][0] 67 | flying_dir = flying_path + '/TRAIN/' 68 | subdir = ['A', 'B', 'C'] 69 | 70 | for ss in subdir: 71 | flying = os.listdir(flying_dir + ss) 72 | 73 | for ff in flying: 74 | imm_l = os.listdir(flying_dir + ss + '/' + ff + '/left/') 75 | for im in imm_l: 76 | if is_image_file(flying_dir + ss + '/' + ff + '/left/' + im): 77 | all_left_img.append(flying_dir + ss + '/' + ff + '/left/' + im) 78 | 79 | all_left_disp.append(flying_disp + '/TRAIN/' + ss + '/' + ff + '/left/' + im.split(".")[0] + '.pfm') 80 | 81 | if is_image_file(flying_dir + ss + '/' + ff + '/right/' + im): 82 | all_right_img.append(flying_dir + ss + '/' + ff + '/right/' + im) 83 | 84 | flying_dir = flying_path + '/TEST/' 85 | 86 | subdir = ['A', 'B', 'C'] 87 | 88 | for ss in subdir: 89 | flying = os.listdir(flying_dir + ss) 90 | 91 | for ff in flying: 92 | imm_l = os.listdir(flying_dir + ss + '/' + ff + '/left/') 93 | for im in imm_l: 94 | if is_image_file(flying_dir + ss + '/' + ff + '/left/' + im): 95 | test_left_img.append(flying_dir + ss + '/' + ff + '/left/' + im) 96 | 97 | test_left_disp.append(flying_disp + '/TEST/' + ss + '/' + ff + '/left/' + im.split(".")[0] + '.pfm') 98 | 99 | if is_image_file(flying_dir + ss + '/' + ff + '/right/' + im): 100 | test_right_img.append(flying_dir + ss + '/' + ff + '/right/' + im) 101 | 102 | driving_dir = filepath + [x for x in image if 'driving' in x][0] + '/' 103 | driving_disp = filepath + [x for x in disp if 'driving' in x][0] 104 | 105 | subdir1 = ['15mm_focallength', '15mm_focallength'] 106 | subdir2 = ['scene_backwards', 'scene_forwards'] 107 | subdir3 = ['fast', 'slow'] 108 | 109 | for i in subdir1: 110 | for j in subdir2: 111 | for k in subdir3: 112 | imm_l = os.listdir(driving_dir + i + '/' + j + '/' + k + '/left/') 113 | for im in imm_l: 114 | if is_image_file(driving_dir + i + '/' + j + '/' + k + '/left/' + im): 115 | all_left_img.append(driving_dir + i + '/' + j + '/' + k + '/left/' + im) 116 | 
all_left_disp.append( 117 | driving_disp + '/' + i + '/' + j + '/' + k + '/left/' + im.split(".")[0] + '.pfm') 118 | 119 | if is_image_file(driving_dir + i + '/' + j + '/' + k + '/right/' + im): 120 | all_right_img.append(driving_dir + i + '/' + j + '/' + k + '/right/' + im) 121 | 122 | return all_left_img, all_right_img, all_left_disp, test_left_img, test_right_img, test_left_disp 123 | 124 | -------------------------------------------------------------------------------- /paddle_env.yml: -------------------------------------------------------------------------------- 1 | name: paddle 2 | channels: 3 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main 4 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/Paddle/ 5 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge 6 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/ 7 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/ 8 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/ 9 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/r/ 10 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/menpo/ 11 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/msys2/ 12 | dependencies: 13 | - _libgcc_mutex=0.1=conda_forge 14 | - _openmp_mutex=4.5=1_gnu 15 | - astor=0.8.1=pyh9f0ad1d_0 16 | - brotlipy=0.7.0=py38h8df0ef7_1001 17 | - bzip2=1.0.8=h516909a_3 18 | - c-ares=1.17.1=h36c2ea0_0 19 | - ca-certificates=2020.11.8=ha878542_0 20 | - cairo=1.16.0=h9f066cc_1006 21 | - certifi=2020.11.8=py38h578d9bd_0 22 | - cffi=1.14.4=py38ha312104_0 23 | - chardet=3.0.4=py38h924ce5b_1008 24 | - cryptography=3.2.1=py38h7699a38_0 25 | - cudatoolkit=10.2.89=h8f6ccaa_6 26 | - cudnn=7.6.5=cuda10.2_0 27 | - cycler=0.10.0=py_2 28 | - dbus=1.13.6=hfdff14a_1 29 | - decorator=4.4.2=py_0 30 | - expat=2.2.9=he1b5a44_2 31 | - ffmpeg=4.3.1=h3215721_1 32 | - fontconfig=2.13.1=h7e3eb15_1002 33 | - freetype=2.10.4=h7ca028e_0 34 | - fribidi=1.0.10=h36c2ea0_0 35 | - gast=0.4.0=pyh9f0ad1d_0 36 | - 
gettext=0.19.8.1=hf34092f_1004 37 | - glib=2.66.3=h58526e2_0 38 | - gmp=6.2.1=h58526e2_0 39 | - gnutls=3.6.13=h85f3911_1 40 | - graphite2=1.3.13=h58526e2_1001 41 | - graphviz=2.42.3=h0511662_0 42 | - gst-plugins-base=1.14.5=h0935bb2_2 43 | - gstreamer=1.14.5=h36ae1b5_2 44 | - harfbuzz=2.7.2=ha5b49bf_1 45 | - hdf5=1.10.6=nompi_h2750804_1111 46 | - icu=67.1=he1b5a44_0 47 | - idna=2.10=pyh9f0ad1d_0 48 | - jasper=1.900.1=h07fcdf6_1006 49 | - jpeg=9d=h36c2ea0_0 50 | - kiwisolver=1.3.1=py38h82cb98a_0 51 | - krb5=1.17.2=h926e7f8_0 52 | - lame=3.100=h14c3975_1001 53 | - lcms2=2.11=hcbb858e_1 54 | - ld_impl_linux-64=2.35.1=hed1e6ac_0 55 | - libblas=3.9.0=3_openblas 56 | - libcblas=3.9.0=3_openblas 57 | - libclang=10.0.1=default_hde54327_1 58 | - libcurl=7.71.1=hcdd3856_8 59 | - libedit=3.1.20191231=he28a2e2_2 60 | - libev=4.33=h516909a_1 61 | - libevent=2.1.10=hcdb4288_3 62 | - libffi=3.2.1=he1b5a44_1007 63 | - libgcc-ng=9.3.0=h5dbcf3e_17 64 | - libgfortran-ng=9.3.0=he4bcb1c_17 65 | - libgfortran5=9.3.0=he4bcb1c_17 66 | - libglib=2.66.3=hbe7bbb4_0 67 | - libgomp=9.3.0=h5dbcf3e_17 68 | - libiconv=1.16=h516909a_0 69 | - liblapack=3.9.0=3_openblas 70 | - liblapacke=3.9.0=3_openblas 71 | - libllvm10=10.0.1=he513fc3_3 72 | - libnghttp2=1.41.0=h8cfc5f6_2 73 | - libopenblas=0.3.12=pthreads_h4812303_1 74 | - libopencv=4.5.0=py38_3 75 | - libpng=1.6.37=h21135ba_2 76 | - libpq=12.3=h5513abc_2 77 | - libprotobuf=3.13.0.1=h8b12597_0 78 | - libssh2=1.9.0=hab1572f_5 79 | - libstdcxx-ng=9.3.0=h2ae2ef3_17 80 | - libtiff=4.1.0=h4f3a223_6 81 | - libtool=2.4.6=h58526e2_1007 82 | - libuuid=2.32.1=h14c3975_1000 83 | - libwebp-base=1.1.0=h36c2ea0_3 84 | - libxcb=1.13=h14c3975_1002 85 | - libxkbcommon=0.10.0=he1b5a44_0 86 | - libxml2=2.9.10=h68273f3_2 87 | - lz4-c=1.9.2=he1b5a44_3 88 | - matplotlib=3.3.2=0 89 | - matplotlib-base=3.3.2=py38h5c7f4ab_1 90 | - mysql-common=8.0.21=2 91 | - mysql-libs=8.0.21=hf3661c5_2 92 | - ncurses=6.2=h58526e2_4 93 | - nettle=3.6=he412f7d_0 94 | - nltk=3.4.4=py_0 95 
| - nspr=4.29=he1b5a44_1 96 | - nss=3.59=h2c00c37_0 97 | - numpy=1.19.4=py38hf0fd68c_1 98 | - olefile=0.46=pyh9f0ad1d_1 99 | - opencv=4.5.0=py38_3 100 | - openh264=2.1.1=h8b12597_0 101 | - openssl=1.1.1h=h516909a_0 102 | - pango=1.42.4=h69149e4_5 103 | - pathlib=1.0.1=py38h32f6830_3 104 | - pcre=8.44=he1b5a44_0 105 | - pillow=8.0.1=py38h70fbd49_0 106 | - pip=20.2.4=py_0 107 | - pixman=0.40.0=h36c2ea0_0 108 | - protobuf=3.13.0.1=py38hadf7658_1 109 | - pthread-stubs=0.4=h36c2ea0_1001 110 | - py-cpuinfo=5.0.0=py_0 111 | - py-opencv=4.5.0=py38h81c977d_3 112 | - pycparser=2.20=pyh9f0ad1d_2 113 | - pyopenssl=19.1.0=py_1 114 | - pyparsing=2.4.7=pyh9f0ad1d_0 115 | - pysocks=1.7.1=py38h924ce5b_2 116 | - python=3.8.6=h852b56e_0_cpython 117 | - python-dateutil=2.8.1=py_0 118 | - python_abi=3.8=1_cp38 119 | - qt=5.12.9=h1f2b2cb_0 120 | - readline=8.0=he28a2e2_2 121 | - requests=2.25.0=pyhd3deb0d_0 122 | - scipy=1.5.3=py38hb2138dd_0 123 | - setuptools=49.6.0=py38h924ce5b_2 124 | - six=1.15.0=pyh9f0ad1d_0 125 | - sqlite=3.33.0=h4cf870e_1 126 | - tk=8.6.10=hed695b0_1 127 | - tornado=6.1=py38h25fe258_0 128 | - urllib3=1.25.11=py_0 129 | - wheel=0.35.1=pyh9f0ad1d_0 130 | - x264=1!152.20180806=h14c3975_0 131 | - xorg-kbproto=1.0.7=h14c3975_1002 132 | - xorg-libice=1.0.10=h516909a_0 133 | - xorg-libsm=1.2.3=h84519dc_1000 134 | - xorg-libx11=1.6.12=h516909a_0 135 | - xorg-libxau=1.0.9=h14c3975_0 136 | - xorg-libxdmcp=1.1.3=h516909a_0 137 | - xorg-libxext=1.3.4=h516909a_0 138 | - xorg-libxpm=3.5.13=h516909a_0 139 | - xorg-libxrender=0.9.10=h516909a_1002 140 | - xorg-libxt=1.1.5=h516909a_1003 141 | - xorg-renderproto=0.11.1=h14c3975_1002 142 | - xorg-xextproto=7.3.0=h14c3975_1002 143 | - xorg-xproto=7.0.31=h14c3975_1007 144 | - xz=5.2.5=h516909a_1 145 | - zlib=1.2.11=h516909a_1010 146 | - zstd=1.4.5=h6597ccf_2 147 | - pip: 148 | - objgraph==3.4.1 149 | - paddlepaddle-gpu==2.0.0rc0 150 | - prettytable==0.7 151 | prefix: /home/zhb/Documents/miniconda3/envs/paddle 152 | 
-------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import paddle 2 | from paddle.vision.transforms import Compose, Normalize, ToTensor 3 | 4 | import os 5 | import argparse 6 | import glob 7 | import numpy as np 8 | import cv2 9 | import shutil 10 | import time 11 | import PIL.Image as Image 12 | 13 | from dataloader.dataloader import imagenet_stats 14 | from models.models import LWSNet 15 | import utils.logger as logger 16 | 17 | parser = argparse.ArgumentParser(description='Model Inference') 18 | parser.add_argument('--max_disparity', type=int, default=192) 19 | parser.add_argument('--img_path', type=str, default="dataset/kitti2015/testing/") 20 | parser.add_argument('--left_img', type=str, default="") 21 | parser.add_argument('--model', type=str, default="results/finetune/checkpoint.pdparams") 22 | parser.add_argument('--save_path', type=str, default="results/inference") 23 | parser.add_argument('--maxdisplist', type=int, nargs='+', default=[24, 5, 5]) 24 | parser.add_argument('--channels_3d', type=int, default=8, help='number of initial channels 3d feature extractor ') 25 | parser.add_argument('--layers_3d', type=int, default=4, help='number of initial layers in 3d network') 26 | parser.add_argument('--growth_rate', type=int, nargs='+', default=[4,1,1], help='growth rate in the 3d network') 27 | parser.add_argument('--gpu_id', type=int, default=0) 28 | parser.add_argument('--vis', action='store_true', default=False, help="Show inference results") 29 | args = parser.parse_args() 30 | 31 | def main(): 32 | 33 | LOG = logger.setup_logger(__file__, "./log/") 34 | for key, value in vars(args).items(): 35 | LOG.info(str(key) + ': ' + str(value)) 36 | 37 | gpu_id = args.gpu_id 38 | place = paddle.set_device("gpu:" + str(gpu_id)) 39 | 40 | model = LWSNet(args) 41 | if not os.path.isfile(args.model): 42 | LOG.info("No model load") 43 | raise 
SystemExit 44 | else: 45 | model.set_state_dict(paddle.load(args.model)) 46 | LOG.info("Successful load model") 47 | 48 | model.eval() 49 | 50 | if not args.left_img: 51 | if os.path.isdir(args.img_path): 52 | left_imgs_path = sorted(glob.glob(args.img_path + "image_2/*.png")) 53 | right_imgs_path = sorted(glob.glob(args.img_path + "image_3/*.png")) 54 | elif os.path.isfile(args.img_path): 55 | temp_path, img_name = args.img_path.split("/")[0:-2], args.img_path.split("/")[-1] 56 | temp_path = "/".join(temp_path) 57 | left_imgs_path = [os.path.join(temp_path, "image_2/"+img_name)] 58 | right_imgs_path = [os.path.join(temp_path, "image_3/"+img_name)] 59 | LOG.info("Load data path") 60 | 61 | if os.path.exists(args.save_path): 62 | shutil.rmtree(args.save_path) 63 | os.makedirs(args.save_path) 64 | LOG.info("Clear all files in the path: {}".format(args.save_path)) 65 | 66 | else : 67 | temp_path = args.left_img.split("/")[0:-1] 68 | temp_path = "/".join(temp_path) 69 | left_imgs_path = [args.left_img] 70 | right_imgs_path = [os.path.join(temp_path, "right_test.png")] 71 | 72 | LOG.info("Begin inference!") 73 | 74 | inference(model, left_imgs_path, right_imgs_path, LOG) 75 | 76 | LOG.info("End inference!") 77 | 78 | def inference(model, left_imgs, right_ims, LOG): 79 | 80 | stages = 4 81 | model.eval() 82 | 83 | transform = Compose([ToTensor(), 84 | Normalize(mean=imagenet_stats["mean"], 85 | std=imagenet_stats["std"])]) 86 | 87 | for index in range(len(left_imgs)): 88 | # LOG.info("left = {}\tright = {}".format(left_imgs[index], right_ims[index])) 89 | 90 | left_img = cv2.imread(left_imgs[index], cv2.IMREAD_UNCHANGED) 91 | right_img = cv2.imread(right_ims[index], cv2.IMREAD_UNCHANGED) 92 | 93 | h, w, c = left_img.shape 94 | th, tw = 368, 1232 95 | 96 | if h
Raw left image
139 | 140 |  141 | 142 |Stage 1
143 | 144 |  145 | 146 |Stage 2
147 | 148 |  149 | 150 |Stage 3
151 | 152 |  153 | 154 |Stage 4
155 | 156 |  157 | 158 | 159 | -------------------------------------------------------------------------------- /models/models.py: -------------------------------------------------------------------------------- 1 | import paddle 2 | import paddle.nn as nn 3 | import paddle.nn.functional as F 4 | 5 | from models.submodules import feature_extraction, post_3dconvs, refinement1, refinement2 6 | 7 | class LWSNet(nn.Layer): 8 | def __init__(self, args): 9 | super(LWSNet, self).__init__() 10 | 11 | self.maxdisplist = args.maxdisplist 12 | self.layers_3d = args.layers_3d 13 | self.channels_3d = args.channels_3d 14 | self.growth_rate = args.growth_rate 15 | 16 | self.feature_extraction = feature_extraction() 17 | self.volume_postprocess = [] 18 | 19 | for i in range(3): 20 | net3d = post_3dconvs(self.layers_3d, self.channels_3d*self.growth_rate[i]) 21 | self.volume_postprocess.append(net3d) 22 | self.volume_postprocess = nn.LayerList(self.volume_postprocess) #3D CNN in Stage 1 to Stage 3 23 | 24 | self.refinement1_left = refinement1(in_channels=3, out_channels=32) #input: left image output: left features 25 | self.refinement1_disp = refinement1(in_channels=1, out_channels=32) #input: disparity stage 3 output: disparity features 26 | self.refinement2 = refinement2(in_channels=64, out_channels=32) 27 | 28 | def warp(self, x, disp): 29 | """ 30 | warp an image/tensor (im2) back to im1, according to the optical flow 31 | x: [B, C, H, W] (im2) 32 | disp: [B, 1, H, W] 33 | flo: [B, 2, H, W] flow 34 | output: [B, C, H, W] (im1) 35 | """ 36 | B, C, H, W = x.shape 37 | # mesh grid 38 | xx = paddle.expand(paddle.arange(0, W, step=1, dtype='float32').reshape(shape=[1, -1]), shape=[H, W]) 39 | yy = paddle.expand(paddle.arange(0, H, step=1, dtype='float32').reshape(shape=[-1, 1]), shape=[H, W]) 40 | 41 | xx = paddle.expand(xx.reshape(shape=[1, 1, H, W]), shape=[B, 1, H, W]) 42 | yy = paddle.expand(yy.reshape(shape=[1, 1, H, W]), shape=[B, 1, H, W]) 43 | 44 | vgrid = paddle.concat((xx, 
yy), axis=1) #[B, 2, H, W] 45 | vgrid[:, :1, :, :] = vgrid[:, :1, :, :] - disp 46 | # scale grid to [-1,1] 47 | vgrid[:, 0, :, :] = 2.0 * vgrid[:, 0, :, :] / max(W - 1, 1) - 1.0 48 | vgrid[:, 1, :, :] = 2.0 * vgrid[:, 1, :, :] / max(H - 1, 1) - 1.0 49 | 50 | vgrid = paddle.transpose(vgrid, [0, 2, 3, 1]) #[B, H, W, 2] 51 | vgrid.stop_gradient = False 52 | 53 | output = F.grid_sample(x, vgrid) 54 | 55 | return output 56 | 57 | 58 | def _build_volume_2d(self, feat_l, feat_r, maxdisp, stride=1): 59 | """ 60 | output full disparity map 61 | L1 distance-based cost 62 | """ 63 | assert maxdisp % stride == 0 64 | 65 | cost = paddle.zeros((feat_l.shape[0], maxdisp // stride, feat_l.shape[2], feat_l.shape[3]), dtype='float32') 66 | cost.stop_gradient=False 67 | 68 | for i in range(0, maxdisp, stride): 69 | 70 | if i > 0: 71 | cost[:, i // stride, :, :i] = feat_l[:, :, :, :i].abs().sum(axis=1) #occlusion regions 72 | cost[:, i // stride, :, i:] = paddle.norm(feat_l[:, :, :, i:] - feat_r[:, :, :, :-i], 1, 1) 73 | else: 74 | cost[:, i // stride, :, i:] = paddle.norm(feat_l[:, :, :, :] - feat_r[:, :, :, :], 1, 1) 75 | 76 | return cost 77 | 78 | def _build_volume_2d3(self, feat_l, feat_r, maxdisp, disp, stride=1): 79 | """ 80 | output residual map 81 | L1 distance-based cost 82 | """ 83 | size = feat_l.shape 84 | 85 | disp = paddle.unsqueeze(disp, axis=1) 86 | batch_disp = paddle.expand(disp, shape=[disp.shape[0], maxdisp * 2 - 1, disp.shape[-3], disp.shape[-2], 87 | disp.shape[-1]]) 88 | batch_disp = batch_disp.reshape(shape=[-1, 1, size[-2], size[-1]]) 89 | 90 | batch_shift = paddle.arange(-maxdisp + 1, maxdisp, dtype="float32") 91 | batch_shift = paddle.expand(batch_shift, shape=[size[0], batch_shift.shape[0]]).reshape(shape=[-1]).unsqueeze( 92 | axis=[1, 2, 3]) * stride 93 | batch_disp = batch_disp - batch_shift 94 | batch_feat_l = paddle.unsqueeze(feat_l, axis=1).expand( 95 | shape=[size[0], maxdisp * 2 - 1, size[-3], size[-2], size[-1]]).reshape( 96 | shape=[-1, size[-3], 
size[-2], size[-1]]) 97 | batch_feat_r = paddle.unsqueeze(feat_r, axis=1).expand( 98 | shape=[size[0], maxdisp * 2 - 1, size[-3], size[-2], size[-1]]).reshape( 99 | shape=[-1, size[-3], size[-2], size[-1]]) 100 | 101 | cost = paddle.norm(batch_feat_l - self.warp(batch_feat_r, batch_disp), 1, 1) #output residual map 102 | cost = cost.reshape(shape=[size[0], -1, size[2], size[3]]) 103 | 104 | return cost 105 | 106 | def forward(self, left_input, right_input): 107 | 108 | img_size = left_input.shape 109 | 110 | feats_l = self.feature_extraction(left_input) #left features 111 | feats_r = self.feature_extraction(right_input) #right features 112 | 113 | pred = [] 114 | 115 | for scale in range(len(feats_l)): 116 | 117 | if scale > 0: 118 | #stage 2 and stage 3 119 | wflow = F.interpolate(pred[scale - 1], size=[feats_l[scale].shape[2], feats_l[scale].shape[3]], 120 | mode="bilinear") * \ 121 | feats_l[scale].shape[2] / img_size[2] #resize disparity of last stage to current resolution 122 | 123 | cost = self._build_volume_2d3(feats_l[scale], 124 | feats_r[scale], 125 | self.maxdisplist[scale], 126 | wflow, 127 | stride=1) #build cost volume 128 | 129 | else: 130 | #stage 1 131 | cost = self._build_volume_2d(feats_l[scale], 132 | feats_r[scale], 133 | self.maxdisplist[scale], 134 | stride=1) #build cost volume 135 | 136 | cost = paddle.unsqueeze(cost, [1]) 137 | cost = self.volume_postprocess[scale](cost) + cost #3D CNN, skip connection 138 | cost = paddle.squeeze(cost, [1]) 139 | 140 | if scale == 0: 141 | #stage 1 142 | pre_low_res = disparity_regression(start=0, end=self.maxdisplist[0])(input=F.softmax(-cost, axis=1)) #full disparity 143 | #softmax function computes the probability of a pixel's disparity to be d 144 | #'softmax(cost)' or 'softmax(-cost)' do not affect the performance because feature-based cost volume provided flexibility. 
145 | pre_low_res = pre_low_res * img_size[2] / pre_low_res.shape[2] #transform disparity value to original resolution 146 | disp_up = F.interpolate(pre_low_res, size=[img_size[2], img_size[3]], mode="bilinear") #upsample to original resolution 147 | 148 | pred.append(disp_up) 149 | else: 150 | #stage 2 and 3 151 | pre_low_res = disparity_regression(start=-self.maxdisplist[scale] + 1, 152 | end=self.maxdisplist[scale])(input=F.softmax(-cost, axis=1)) #residual 153 | pre_low_res = pre_low_res * img_size[2] / pre_low_res.shape[2] #transform residual value to original resolution 154 | disp_up = F.interpolate(pre_low_res, size=[img_size[2], img_size[3]], mode="bilinear") #upsample to original resolution 155 | 156 | pred.append(disp_up + pred[scale - 1]) #skip connection 157 | 158 | refined_left = self.refinement1_left(left_input) 159 | refined_disp = self.refinement1_disp(pred[-1]) 160 | disp = self.refinement2(input=paddle.concat([refined_left, refined_disp], 1)) 161 | disp_up = F.interpolate(disp, size=[img_size[2], img_size[3]], mode="bilinear") 162 | pred.append(pred[2] + disp_up) #skip connection 163 | 164 | return pred #disparity maps of 4 stages 165 | 166 | 167 | class disparity_regression(nn.Layer): 168 | def __init__(self, start, end, stride=1): 169 | super(disparity_regression, self).__init__() 170 | self.disp = paddle.arange(start * stride, end * stride, stride, dtype='float32') 171 | self.disp.stop_gradient = True 172 | self.disp = paddle.reshape(self.disp, shape=[1, -1, 1, 1]) 173 | _, self.my_steplength, _, _ = self.disp.shape 174 | 175 | def forward(self, input): 176 | 177 | disp = paddle.expand(self.disp, (input.shape[0], self.my_steplength, input.shape[2], input.shape[3])) 178 | output = paddle.sum(input * disp, axis=1, keepdim=True) #compute expectation 179 | return output -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | 
import argparse
import paddle
import paddle.nn.functional as F
import paddle.fluid as fluid  # NOTE(review): no longer referenced in this script; candidate for removal.
from paddle.io import DataLoader

import numpy as np
import os
import glob
import math
import time

from models.models import LWSNet
from dataloader import sceneflow as sf
from dataloader import dataloader
import utils.logger as logger
from utils.utils import AverageMeter as AverageMeter

parser = argparse.ArgumentParser(description='pretrain Sceneflow main()')

parser.add_argument('--maxdisp', type=int, default=192,
                    help='maxium disparity')
parser.add_argument('--datapath', default='dataset/sceneflow/')
parser.add_argument('--loss_weights', type=float, nargs='+', default=[0.25, 0.5, 1., 1.])
parser.add_argument('--max_disparity', type=int, default=192)
parser.add_argument('--maxdisplist', type=int, nargs='+', default=[24, 5, 5])
parser.add_argument('--channels_3d', type=int, default=8, help='number of initial channels 3d feature extractor ')
parser.add_argument('--layers_3d', type=int, default=4, help='number of initial layers in 3d network')
parser.add_argument('--growth_rate', type=int, nargs='+', default=[4,1,1], help='growth rate in the 3d network')
parser.add_argument('--lr', type=float, default=5e-4, help='learning rate')
parser.add_argument('--epoch', type=int, default=10)
parser.add_argument('--last_epoch', type=int, default=-1)
parser.add_argument('--train_batch_size', type=int, default=8)
parser.add_argument('--test_batch_size', type=int, default=8)
parser.add_argument('--gpu_id', type=int, default=0)
parser.add_argument('--save_path', type=str, default="results/pretrained/")
parser.add_argument('--model', type=str, default="checkpoint")
parser.add_argument('--resume', type=str, default="")
args = parser.parse_args()


def main():
    """Pre-train LWSNet on SceneFlow.

    Builds the train/test loaders. Optionally resumes from ``args.resume``, which
    may hold ``*.pdparams`` (model weights), ``*.pdopt`` (optimizer state) and
    ``*.params`` (epoch / lr / best error / elapsed time). It then alternates
    one training epoch with one test pass. Whenever the test EPE improves, it
    saves a checkpoint under ``args.save_path``.
    """

    # configuration logger
    LOG = logger.setup_logger(__file__, "./log/")
    for key, value in vars(args).items():
        LOG.info(str(key) + ': ' + str(value))

    LOG.info("pretrain Sceneflow main()")

    gpu_id = args.gpu_id
    paddle.set_device("gpu:" + str(gpu_id))

    # get train and test dataset path
    (train_left_img, train_right_img, train_left_disp,
     test_left_img, test_right_img, test_left_disp) = sf.dataloader(args.datapath)

    # train and test dataloader
    train_loader = paddle.io.DataLoader(
        dataloader.MyDataloader(train_left_img, train_right_img, train_left_disp, training=True, kitti_set=False),
        batch_size=args.train_batch_size, places=paddle.CUDAPlace(gpu_id), shuffle=True, drop_last=False,
        num_workers=2)
    test_loader = paddle.io.DataLoader(
        dataloader.MyDataloader(test_left_img, test_right_img, test_left_disp, training=False, kitti_set=False),
        batch_size=args.test_batch_size, places=paddle.CUDAPlace(gpu_id), shuffle=False, drop_last=False,
        num_workers=2)

    train_batch_len, test_batch_len = len(train_loader), len(test_loader)
    LOG.info("train batch_len {} test batch_len {}".format(train_batch_len, test_batch_len))

    os.makedirs(args.save_path, exist_ok=True)
    save_filename = os.path.join(args.save_path, args.model)

    # load model
    model = LWSNet(args)

    last_epoch = 0
    error_check = math.inf  # best test EPE so far; any finite error beats it
    start_time = time.time()

    # Setup optimizer
    optimizer = paddle.optimizer.Adam(learning_rate=args.lr, parameters=model.parameters())

    if args.resume:
        model_files = glob.glob(args.resume + "/*.pdparams")
        if model_files:
            model.set_state_dict(paddle.load(model_files[0]))
            LOG.info("load model state")

        opt_files = glob.glob(args.resume + "/*.pdopt")
        if opt_files:
            optimizer.set_state_dict(paddle.load(opt_files[0]))
            LOG.info("load optimizer state")

        param_files = glob.glob(args.resume + "/*.params")
        if param_files:
            param_state = paddle.load(param_files[0])
            last_epoch = param_state["epoch"] + 1  # continue after the saved epoch
            last_lr = param_state["lr"]
            error_check = param_state["error"]
            # Shift the clock so the final "full training time" includes earlier runs.
            start_time = start_time - param_state["time_cost"]
            LOG.info("load last epoch = {}\tlr = {:.5f}\terror = {:.4f}\ttime_cost = {:.2f} Hours"
                     .format(last_epoch, last_lr, error_check, param_state["time_cost"] / 3600))

        LOG.info("resume successfully")

    if args.last_epoch != -1:
        last_epoch = args.last_epoch  # explicit CLI override wins over the checkpoint

    for epoch in range(last_epoch, args.epoch):

        train(model, train_loader, optimizer, epoch, LOG)
        error = test(model, test_loader, epoch, LOG)

        if error < error_check:
            error_check = error

            paddle.save(model.state_dict(), save_filename + ".pdparams")
            paddle.save(optimizer.state_dict(), save_filename + ".pdopt")
            paddle.save({"epoch": epoch,
                         "lr": optimizer.get_lr(),
                         "error": error_check,
                         "time_cost": time.time() - start_time},
                        save_filename + ".params")
            LOG.info("save model param success")

    LOG.info('full training time = {:.2f} Hours'.format((time.time() - start_time) / 3600))


# Train function
def train(model, data_loader, optimizer, epoch, LOG):
    """Run one SceneFlow training epoch.

    For every stage output, a smooth-L1 loss is computed on the pixels with
    ``gt < args.maxdisp``. It is weighted by ``args.loss_weights`` and the
    weighted stage losses are summed for a single backward pass. Batches
    without any valid pixel are skipped.
    """
    stages = 4  # LWSNet returns 3 stage maps + 1 refined map
    losses = [AverageMeter() for _ in range(stages)]
    length_loader = len(data_loader)
    model.train()

    for batch_id, (left_img, right_img, gt) in enumerate(data_loader()):
        valid = gt.numpy() < args.maxdisp
        # Skip before running the model: an empty mask would leave nothing to average.
        if not valid.any():
            continue
        mask = paddle.to_tensor(valid)
        gt_mask = paddle.masked_select(gt, mask)

        outputs = model(left_img, right_img)
        # Drop only the singleton channel axis, (B, 1, H, W) -> (B, H, W). A bare squeeze()
        # would also drop the batch axis for a 1-sample batch (drop_last=False above).
        outputs = [paddle.squeeze(output, axis=1) for output in outputs]

        stage_loss = []
        for index in range(stages):
            weight = args.loss_weights[index]
            loss = weight * F.smooth_l1_loss(paddle.masked_select(outputs[index], mask),
                                             gt_mask, reduction='mean')
            stage_loss.append(loss)
            # Log the unweighted loss so the stages are directly comparable.
            losses[index].update(float(loss.numpy()) / weight)

        sum_loss = paddle.add_n(stage_loss)
        sum_loss.backward()
        optimizer.step()
        optimizer.clear_grad()

        if batch_id % 5 == 0:
            info_str = '\t'.join(
                'Stage {} = {:.2f}({:.2f})'.format(x, losses[x].val, losses[x].avg) for x in range(stages))
            LOG.info(
                'Train Epoch{} [{}/{}] lr:{:.5f}\t{}'.format(epoch, batch_id, length_loader, optimizer.get_lr(),
                                                            info_str))

    info_str = '\t'.join('Stage {} = {:.2f}'.format(x, losses[x].avg) for x in range(stages))
    LOG.info('Average train loss = ' + info_str)


# Test function
def test(model, data_loader, epoch, LOG):
    """Evaluate on the SceneFlow test split and return the last stage's mean EPE.

    The EPE is the mean absolute disparity error over pixels with
    ``gt < args.maxdisp``. ``epoch`` is unused and kept for call-site
    compatibility.
    """
    stages = 4
    EPEs = [AverageMeter() for _ in range(stages)]
    length_loader = len(data_loader)
    model.eval()

    for batch_id, (left_img, right_img, gt) in enumerate(data_loader()):
        gt = gt.numpy()
        mask = gt < args.maxdisp
        gt_valid = gt[mask]

        if gt_valid.size:  # nothing to score when no pixel is valid
            with paddle.no_grad():
                outputs = model(left_img, right_img)

                for stage in range(stages):
                    output = paddle.squeeze(outputs[stage], 1).numpy()
                    # Drop the first 4 rows so the prediction lines up with `gt`.
                    # NOTE(review): presumably the test loader pads 4 rows on top; confirm in dataloader.
                    output = output[:, 4:, :]
                    EPEs[stage].update(float(np.mean(np.abs(output[mask] - gt_valid))))

        if batch_id % 5 == 0:
            info_str = '\t'.join(
                'Stage {} = {:.2f}({:.2f})'.format(x, EPEs[x].val, EPEs[x].avg) for x in range(stages))
            LOG.info('Test: [{}/{}] {}'.format(batch_id, length_loader, info_str))

    info_str = ', '.join('Stage {}={:.2f}'.format(x, EPEs[x].avg) for x in range(stages))
    LOG.info('Average test EPE = ' + info_str)

    return EPEs[-1].avg


def error_estimating(disp, ground_truth, maxdisp=192):
    """Return the 3-pixel error rate of ``disp`` against ``ground_truth``.

    Pixels with ``0 < gt < maxdisp`` are evaluated. A pixel counts as wrong when
    its absolute error exceeds both 3 px and 5 % of its ground-truth disparity.
    The 1e-9 in the denominator yields 0.0 (not ZeroDivisionError) when no pixel
    is valid.
    """
    gt = ground_truth
    mask = (gt > 0) & (gt < maxdisp)

    errmap = np.abs(disp - gt)
    err3 = ((errmap[mask] > 3.) & (errmap[mask] / gt[mask] > 0.05)).sum()
    return float(err3) / float(mask.sum() + 1e-9)


if __name__ == "__main__":
    main()

# --------------------------------------------------------------------------------
# /finetune.py
# --------------------------------------------------------------------------------
import paddle
import paddle.nn.functional as F
from paddle.io import DataLoader

import os
import glob
import math
import numpy as np
import argparse
import time

import utils.logger as logger
from utils.utils import AverageMeter as AverageMeter
from models.models import LWSNet
from dataloader import kitti2015load as kitti
from dataloader import dataloader

parser = argparse.ArgumentParser(description='finetune KITTI')

parser.add_argument('--maxdisp', type=int, default=192,
                    help='maxium disparity')
parser.add_argument('--datapath', default='dataset/kitti2015/training/', help='datapath')
parser.add_argument('--loss_weights', type=float, nargs='+', default=[0.25, 0.5, 1., 1.])
parser.add_argument('--max_disparity', type=int, default=192)
parser.add_argument('--maxdisplist', type=int, nargs='+', default=[24, 5, 5])
parser.add_argument('--channels_3d', type=int, default=8, help='number of initial channels 3d feature extractor ')
parser.add_argument('--layers_3d', type=int, default=4, help='number of initial layers in 3d network')
parser.add_argument('--growth_rate', type=int, nargs='+', default=[4,1,1], help='growth rate in the 3d network')
parser.add_argument('--lr', type=float, default=5e-4, help='learning rate')
parser.add_argument('--epoch', type=int, default=300)
parser.add_argument('--last_epoch', type=int, default=-1)
parser.add_argument('--train_batch_size', type=int, default=4)
parser.add_argument('--test_batch_size', type=int, default=8)
parser.add_argument('--gpu_id', type=int, default=0)
parser.add_argument('--save_path', type=str, default="results/finetune")
parser.add_argument('--model', type=str, default="checkpoint")
parser.add_argument('--pretrained', type=str, default="results/pretrained")
parser.add_argument('--resume', type=str, default="")
parser.add_argument('--val_set', type=str, default='val_set.txt')
parser.add_argument('--evaluate', action='store_true', default=False)
args = parser.parse_args()


def main():
    """Fine-tune LWSNet on KITTI 2015, or only evaluate with ``--evaluate``.

    Starts from the ``*.pdparams`` in ``args.pretrained`` unless ``args.resume``
    is set. A resume restores model, optimizer (including scheduler) and
    bookkeeping state. Whenever the validation 3-pixel error improves, it saves
    a checkpoint under ``args.save_path``.
    """

    # configuration logger
    LOG = logger.setup_logger(__file__, "./log/")
    for key, value in sorted(vars(args).items()):
        LOG.info(str(key) + ': ' + str(value))

    LOG.info("finetune KITTI main()")

    gpu_id = args.gpu_id
    place = paddle.set_device("gpu:" + str(gpu_id))

    # get train and test dataset path
    (train_left_img, train_right_img, train_left_disp,
     test_left_img, test_right_img, test_left_disp) = kitti.dataloader(args.datapath, args.val_set)

    # train and test dataloader
    train_loader = paddle.io.DataLoader(
        dataloader.MyDataloader(train_left_img, train_right_img, train_left_disp, training=True),
        batch_size=args.train_batch_size, places=place, shuffle=True, drop_last=False, num_workers=2)
    test_loader = paddle.io.DataLoader(
        dataloader.MyDataloader(test_left_img, test_right_img, test_left_disp, training=False),
        batch_size=args.test_batch_size, places=place, shuffle=False, drop_last=False, num_workers=2)

    train_batch_len, test_batch_len = len(train_loader), len(test_loader)
    LOG.info("train batch_len {} test batch_len {}".format(train_batch_len, test_batch_len))

    os.makedirs(args.save_path, exist_ok=True)
    save_filename = os.path.join(args.save_path, args.model)

    # load model
    model = LWSNet(args)

    last_epoch = 0
    error_check = math.inf  # best validation error so far
    start_time = time.time()

    # Setup optimizer and learning-rate scheduler (stepped once per epoch in train()).
    # NOTE(review): with the default --epoch 300 the 400 milestone is never reached.
    milestones = [200, 400]
    lr_scheduler = paddle.optimizer.lr.MultiStepDecay(learning_rate=args.lr, milestones=milestones, gamma=0.1)
    optimizer = paddle.optimizer.Adam(learning_rate=lr_scheduler, parameters=model.parameters())

    # Load pretrained model or resume model from weight files
    if args.pretrained and not args.resume:
        pretrained_files = glob.glob(args.pretrained + "/*.pdparams")
        if pretrained_files:
            model.set_state_dict(paddle.load(pretrained_files[0]))
            LOG.info("load pretrained model state")

    elif args.resume:
        model_files = glob.glob(args.resume + "/*.pdparams")
        if model_files:
            model.set_state_dict(paddle.load(model_files[0]))
            LOG.info("load model state")

        opt_files = glob.glob(args.resume + "/*.pdopt")
        if opt_files:
            optimizer.set_state_dict(paddle.load(opt_files[0]))
            LOG.info("load optimizer state")

        param_files = glob.glob(args.resume + "/*.params")
        if param_files:
            param_state = paddle.load(param_files[0])
            last_epoch = param_state["epoch"] + 1
            last_lr = param_state["lr"]
            error_check = param_state["error"]
            start_time = start_time - param_state["time_cost"]
            # time_cost is in seconds: 3600 s per hour (was 36000, which under-reported by 10x).
            LOG.info("load last epoch = {}\tlr = {:.5f}\terror = {:.4f}\ttime_cost = {:.2f} Hours"
                     .format(last_epoch, last_lr, error_check, param_state["time_cost"] / 3600))

        LOG.info("resume successfully")

    if args.evaluate:
        test(model, test_loader, LOG)
        return

    if args.last_epoch != -1:
        last_epoch = args.last_epoch

    for epoch in range(last_epoch, args.epoch):

        train(model, train_loader, optimizer, lr_scheduler, epoch, LOG)
        error = test(model, test_loader, LOG)

        if error < error_check:
            error_check = error

            paddle.save(model.state_dict(), save_filename + ".pdparams")
            paddle.save(optimizer.state_dict(), save_filename + ".pdopt")
            paddle.save({"epoch": epoch,
                         "lr": optimizer.get_lr(),
                         "error": error_check,
                         "time_cost": time.time() - start_time},
                        save_filename + ".params")
            LOG.info("save model param success")

    LOG.info('full training time = {:.2f} Hours'.format((time.time() - start_time) / 3600))


# Train function
def train(model, data_loader, optimizer, lr_scheduler, epoch, LOG):
    """Run one KITTI fine-tuning epoch, then advance the LR scheduler once.

    Loss is a weighted sum of per-stage smooth-L1 losses over pixels with
    ``gt > 0``; the sparse KITTI ground truth marks missing pixels with 0.
    Batches with no such pixel are skipped, because their 'mean' loss would
    average over nothing.
    """
    stages = 4
    losses = [AverageMeter() for _ in range(stages)]
    length_loader = len(data_loader)

    model.train()

    for batch_id, (left_img, right_img, gt) in enumerate(data_loader()):
        valid = gt.numpy() > 0
        if not valid.any():
            continue
        mask = paddle.to_tensor(valid)
        gt_mask = paddle.masked_select(gt, mask)

        outputs = model(left_img, right_img)
        # Squeeze only the channel axis so a 1-sample batch keeps its batch axis.
        outputs = [paddle.squeeze(output, axis=1) for output in outputs]

        stage_losses = []
        for index in range(stages):
            weight = args.loss_weights[index]
            stage_loss = weight * F.smooth_l1_loss(paddle.masked_select(outputs[index], mask), gt_mask,
                                                   reduction='mean')
            stage_losses.append(stage_loss)
            losses[index].update(float(stage_loss.numpy() / weight))  # log unweighted loss

        sum_loss = paddle.add_n(stage_losses)
        sum_loss.backward()
        optimizer.step()
        optimizer.clear_grad()

        if batch_id % 5 == 0:
            info_str = '\t'.join(
                'Stage {} = {:.2f}({:.2f})'.format(x, losses[x].val, losses[x].avg) for x in range(stages))
            info_str = 'Train Epoch{} [{}/{}] lr:{:.5f}\t{}'.format(epoch, batch_id, length_loader, optimizer.get_lr(),
                                                                    info_str)
            LOG.info(info_str)

    lr_scheduler.step()  # epoch-based milestones

    info_str = '\t'.join('Stage {} = {:.2f}'.format(x, losses[x].avg) for x in range(stages))
    LOG.info('Average train loss: ' + info_str)


# Test function
def test(model, data_loader, LOG):
    """Evaluate on the validation split and return the last stage's mean 3-pixel error."""
    stages = 4
    D1s = [AverageMeter() for _ in range(stages)]
    length_loader = len(data_loader)

    model.eval()

    for batch_id, (left_img, right_img, gt) in enumerate(data_loader()):
        gt_np = gt.numpy()

        with paddle.no_grad():
            outputs = model(left_img, right_img)

            for stage in range(stages):
                output = paddle.squeeze(outputs[stage], axis=1)  # (B, 1, H, W) -> (B, H, W)
                D1s[stage].update(error_estimating(output.numpy(), gt_np))

        info_str = '\t'.join(
            'Stage {} = {:.4f}({:.4f})'.format(x, D1s[x].val, D1s[x].avg) for x in range(stages))
        LOG.info('Test [{}/{}] {}'.format(batch_id, length_loader, info_str))

    info_str = ', '.join('Stage {}={:.4f}'.format(x, D1s[x].avg) for x in range(stages))
    LOG.info('Average test 3-Pixel Error: ' + info_str)

    return D1s[-1].avg


def error_estimating(disp, ground_truth, maxdisp=192):
    """Return the 3-pixel error rate of ``disp`` against ``ground_truth``.

    Pixels with ``0 < gt < maxdisp`` are evaluated. A pixel is wrong when its
    absolute error exceeds both 3 px and 5 % of its ground-truth disparity.
    Returns 0.0 when no pixel is valid, instead of dividing by zero.
    """
    gt = ground_truth
    mask = (gt > 0) & (gt < maxdisp)
    n_valid = int(mask.sum())
    if n_valid == 0:
        return 0.0

    errmap = np.abs(disp - gt)
    err3 = ((errmap[mask] > 3.) & (errmap[mask] / gt[mask] > 0.05)).sum()
    return float(err3) / float(n_valid)


if __name__ == "__main__":

    main()

# --------------------------------------------------------------------------------
# /models/submodules.py
# --------------------------------------------------------------------------------
import paddle
import paddle.nn as nn
import paddle.nn.functional as F


def convbn(in_channels, out_channels,
           kernel_size, stride, padding, dilation=1,
           conv_param_attr=None, conv_bias_attr=None,
           bn_param_attr=None, bn_bias_attr=None):
    """2D convolution followed by batch normalization.

    When ``dilation > 1`` the padding is set to ``dilation`` and ``padding`` is
    ignored; for a 3x3 kernel at stride 1 this preserves the spatial size.
    ``bn_param_attr`` / ``bn_bias_attr`` are forwarded to ``BatchNorm2D``;
    ``None`` keeps Paddle's defaults.
    """
    return nn.Sequential(nn.Conv2D(in_channels=in_channels,
                                   out_channels=out_channels,
                                   kernel_size=kernel_size,
                                   stride=stride,
                                   padding=dilation if dilation > 1 else padding,
                                   dilation=dilation,
                                   weight_attr=conv_param_attr,
                                   bias_attr=conv_bias_attr),
                         nn.BatchNorm2D(num_features=out_channels,
                                        weight_attr=bn_param_attr,
                                        bias_attr=bn_bias_attr))


def deconvbn(in_channels, out_channels,
             kernel_size, stride, padding, output_padding=1, dilation=1,
             conv_param_attr=None, conv_bias_attr=None,
             bn_param_attr=None, bn_bias_attr=None):
    """2D transposed convolution followed by batch normalization.

    ``dilation`` is forwarded to ``Conv2DTranspose``, and ``bn_param_attr`` /
    ``bn_bias_attr`` to ``BatchNorm2D``. The defaults keep Paddle's defaults.
    """
    return nn.Sequential(nn.Conv2DTranspose(in_channels=in_channels,
                                            out_channels=out_channels,
                                            kernel_size=kernel_size,
                                            padding=padding,
                                            output_padding=output_padding,
                                            stride=stride,
                                            dilation=dilation,
                                            weight_attr=conv_param_attr,
                                            bias_attr=conv_bias_attr),
                         nn.BatchNorm2D(num_features=out_channels,
                                        weight_attr=bn_param_attr,
                                        bias_attr=bn_bias_attr))


class hourglass(nn.Layer):
    """Encoder-decoder block that returns features at three resolutions.

    Two stride-2 convolutions downsample twice and two stride-2 transposed
    convolutions upsample back, with a skip connection from ``conv2`` into
    ``conv5``. ``forward`` returns three maps in order: the bottleneck output,
    the first upsampled map, and the second upsampled map. The channel and
    resolution comments below assume ``init_channel=8`` and a 1/2-resolution
    input.
    """

    def __init__(self, init_channel=8):
        super(hourglass, self).__init__()
        self.init_channel = init_channel

        self.conv1 = nn.Sequential(convbn(in_channels=self.init_channel,
                                          out_channels=self.init_channel * 2,
                                          kernel_size=3,
                                          stride=2,
                                          padding=1,
                                          conv_param_attr=nn.initializer.KaimingNormal(),
                                          conv_bias_attr=False),
                                   nn.ReLU())

        self.conv2 = nn.Sequential(convbn(in_channels=self.init_channel * 2,
                                          out_channels=self.init_channel * 2,
                                          kernel_size=3,
                                          stride=1,
                                          padding=1,
                                          conv_param_attr=nn.initializer.KaimingNormal(),
                                          conv_bias_attr=False),
                                   nn.ReLU())

        self.conv3 = nn.Sequential(convbn(in_channels=self.init_channel * 2,
                                          out_channels=self.init_channel * 2,
                                          kernel_size=3,
                                          stride=2,
                                          padding=1,
                                          conv_param_attr=nn.initializer.KaimingNormal(),
                                          conv_bias_attr=False),
                                   nn.ReLU())

        self.conv4 = nn.Sequential(convbn(in_channels=self.init_channel * 2,
                                          out_channels=self.init_channel * 2,
                                          kernel_size=3,
                                          stride=1,
                                          padding=1,
                                          conv_param_attr=nn.initializer.KaimingNormal(),
                                          conv_bias_attr=False),
                                   nn.ReLU())

        self.conv5 = deconvbn(in_channels=self.init_channel * 2,
                              out_channels=self.init_channel * 2,
                              kernel_size=3,
                              padding=1,
                              output_padding=1,
                              stride=2,
                              conv_param_attr=nn.initializer.KaimingNormal(),
                              conv_bias_attr=False)  # output is added to conv2's output in forward

        self.conv6 = deconvbn(in_channels=self.init_channel * 2,
                              out_channels=self.init_channel,
                              kernel_size=3,
                              padding=1,
                              output_padding=1,
                              stride=2,
                              conv_param_attr=nn.initializer.KaimingNormal(),
                              conv_bias_attr=False)

    def forward(self, input):
        res = []
        output = self.conv1(input)  # in: 1/2 out: 1/4 channel: 8 to 16
        pre = self.conv2(output)  # in: 1/4 out: 1/4 channel: 16 to 16

        output = self.conv3(pre)  # in: 1/4 out: 1/8 channel: 16 to 16
        output = self.conv4(output)  # in: 1/8 out: 1/8 channel: 16 to 16
        res.append(output)  # feature maps(1/8) channel: 16

        post = F.relu(self.conv5(output) + pre)  # in: 1/8 out: 1/4 channel: 16 to 16
        res.append(post)  # feature maps(1/4) channel: 16

        output = self.conv6(post)  # in: 1/4 out: 1/2 channel: 16 to 8
        res.append(output)  # feature maps(1/2) channel: 8

        return res  # feature maps(1/8, 1/4, 1/2)


class feature_extraction(nn.Layer):
    """Shared 2D feature extractor applied to each view.

    ``forward`` returns three feature maps at 1/8, 1/4 and 1/2 of the input
    resolution. The last map is post-processed by ``classif1``.
    """

    def __init__(self):
        super(feature_extraction, self).__init__()

        self.dres0 = nn.Sequential(convbn(in_channels=3,
                                          out_channels=4,
                                          kernel_size=3,
                                          stride=2,
                                          padding=1,
                                          dilation=2,
                                          conv_param_attr=nn.initializer.KaimingNormal(),
                                          conv_bias_attr=False),
                                   nn.ReLU(),
                                   convbn(in_channels=4,
                                          out_channels=8,
                                          kernel_size=3,
                                          stride=1,
                                          padding=1,
                                          dilation=4,
                                          conv_param_attr=nn.initializer.KaimingNormal(),
                                          conv_bias_attr=False),
                                   nn.ReLU())

        self.dres1 = nn.Sequential(convbn(in_channels=8,
                                          out_channels=4,
                                          kernel_size=3,
                                          stride=1,
                                          padding=1,
                                          dilation=2,
                                          conv_param_attr=nn.initializer.KaimingNormal(),
                                          conv_bias_attr=False),
                                   nn.ReLU(),
                                   convbn(in_channels=4,
                                          out_channels=8,
                                          kernel_size=3,
                                          stride=1,
                                          padding=1,
                                          dilation=2,
                                          conv_param_attr=nn.initializer.KaimingNormal(),
                                          conv_bias_attr=False))

        self.dres2 = hourglass(init_channel=8)

        self.classif1 = nn.Sequential(convbn(in_channels=8,
                                             out_channels=8,
                                             kernel_size=3,
                                             stride=1,
                                             padding=1,
                                             dilation=1,
                                             conv_param_attr=nn.initializer.KaimingNormal(),
                                             conv_bias_attr=False),
                                      nn.ReLU(),
                                      nn.Conv2D(in_channels=8,
                                                out_channels=8,
                                                kernel_size=3,
                                                padding=1,
                                                stride=1,
                                                weight_attr=nn.initializer.KaimingNormal(),
                                                bias_attr=False))

    def forward(self, input):
        output = self.dres0(input)  # in: 1 out: 1/2 channel: 3 to 8
        output = self.dres1(output) + output  # in: 1/2 out: 1/2 channel: 8 to 8 (residual)

        res = self.dres2(output)  # in: 1/2 out: 1/8, 1/4, 1/2
        output = res[-1] + output  # skip connection

        output = self.classif1(output)  # in: 1/2 out: 1/2 channel: 8 to 8
        res.pop(-1)  # replace the raw 1/2 map with the post-processed one
        res.append(output)  # feature maps(1/2) channel: 8

        return res


def batch_relu_conv3d(in_channels, out_channels,
                      kernel_size=3, stride=1, padding=1, bn3d=True,
                      conv_param_attr=nn.initializer.KaimingNormal(), conv_bias_attr=False,
                      bn_param_attr=None, bn_bias_attr=None):
    """Pre-activation 3D block: [BatchNorm3D +] ReLU + Conv3D.

    With ``bn3d=False`` the BatchNorm3D is omitted. ``bn_param_attr`` /
    ``bn_bias_attr`` are forwarded to ``BatchNorm3D``; ``None`` keeps Paddle's
    defaults.
    """
    # Layers are created in the same order and Sequential positions as before,
    # so saved state_dict keys still match.
    layers = [nn.BatchNorm3D(num_features=in_channels,
                             weight_attr=bn_param_attr,
                             bias_attr=bn_bias_attr)] if bn3d else []
    layers += [nn.ReLU(),
               nn.Conv3D(in_channels=in_channels,
                         out_channels=out_channels,
                         kernel_size=kernel_size,
                         padding=padding,
                         stride=stride,
                         weight_attr=conv_param_attr,
                         bias_attr=conv_bias_attr)]
    return nn.Sequential(*layers)


def post_3dconvs(layers, channels):
    """3D CNN applied to a single-channel cost volume: 1 -> channels x (layers + 1) -> 1."""
    net = [batch_relu_conv3d(1, channels)]
    net = net + [batch_relu_conv3d(channels, channels) for _ in range(layers)]
    net = net + [batch_relu_conv3d(channels, 1)]
    return nn.Sequential(*net)


def preconv2d(in_channels, out_channels, kernel_size, stride, pad, dilation=1, bn=True):
    """Pre-activation 2D block: [BatchNorm2D +] ReLU + Conv2D.

    With ``dilation > 1`` the padding is ``dilation``; otherwise it is ``pad``.
    ``bn=False`` now returns the ReLU + Conv2D stack (it used to return None),
    mirroring ``preconv2d_depthseperated``.
    """
    layers = [nn.BatchNorm2D(num_features=in_channels)] if bn else []
    layers += [nn.ReLU(),
               nn.Conv2D(in_channels=in_channels,
                         out_channels=out_channels,
                         kernel_size=kernel_size,
                         stride=stride,
                         padding=dilation if dilation > 1 else pad,
                         dilation=dilation,
                         weight_attr=nn.initializer.KaimingNormal(),
                         bias_attr=False)]
    return nn.Sequential(*layers)


def preconv2d_depthseperated(in_channels, out_channels,
                             kernel_size, stride, pad,
                             dilation=1, bn=True):
    """Pre-activation depthwise-separable 2D block.

    The stack is [BatchNorm2D +] ReLU, then a depthwise ``kernel_size``
    convolution (``groups=in_channels``), then a pointwise 1x1 convolution to
    ``out_channels``.
    """
    layers = [nn.BatchNorm2D(num_features=in_channels)] if bn else []
    layers += [nn.ReLU(),
               nn.Conv2D(in_channels=in_channels,
                         out_channels=in_channels,
                         kernel_size=kernel_size,
                         stride=stride,
                         padding=dilation if dilation > 1 else pad,
                         dilation=dilation,
                         weight_attr=nn.initializer.KaimingNormal(),
                         bias_attr=False,
                         groups=in_channels),
               nn.Conv2D(in_channels=in_channels,
                         out_channels=out_channels,
                         kernel_size=1,
                         stride=1,
                         padding=0,
                         weight_attr=nn.initializer.KaimingNormal(),
                         bias_attr=False)]
    return nn.Sequential(*layers)


def refinement1(in_channels, out_channels):
    """Color-guidance branch applied to the left image or to the last-stage disparity.

    A 3x3 conv maps to ``out_channels``. It is followed by four depthwise-separable
    blocks with dilations 2, 4, 8, 16.
    """
    net = [nn.Conv2D(in_channels=in_channels,
                     out_channels=out_channels,
                     kernel_size=3,
                     stride=1,
                     padding=1,
                     weight_attr=nn.initializer.KaimingNormal(),
                     bias_attr=False
                     )]

    net = net + [preconv2d_depthseperated(in_channels=out_channels,
                                          out_channels=out_channels,
                                          kernel_size=3,
                                          stride=1,
                                          pad=1,
                                          dilation=2 ** (k + 1)) for k in range(4)]

    return nn.Sequential(*net)


def refinement2(in_channels, out_channels):
    """Color-guidance refinement on the concatenated features, producing a 1-channel map.

    The stack is a dilation-8 pre-activation conv, then depthwise-separable
    blocks with dilations 8, 4, 2, 1, then a 3x3 conv to one channel.
    """
    net = [preconv2d(in_channels=in_channels,
                     out_channels=out_channels,
                     kernel_size=3,
                     stride=1,
                     pad=1,
                     dilation=8)]

    net = net + [preconv2d_depthseperated(in_channels=out_channels,
                                          out_channels=out_channels,
                                          kernel_size=3,
                                          stride=1,
                                          pad=1,
                                          dilation=2 ** k) for k in reversed(range(4))]

    net = net + [nn.Conv2D(in_channels=out_channels,
                           out_channels=1,
                           kernel_size=3,
                           stride=1,
                           padding=1,
                           weight_attr=nn.initializer.KaimingNormal(),
                           bias_attr=False
                           )]

    return nn.Sequential(*net)