├── .idea
│   ├── .gitignore
│   ├── inspectionProfiles
│   │   ├── profiles_settings.xml
│   │   └── Project_Default.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── sshConfigs.xml
│   ├── deployment.xml
│   ├── webServers.xml
│   ├── LWANet.iml
│   └── workspace.xml
├── utils
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── logger.cpython-36.pyc
│   │   └── __init__.cpython-36.pyc
│   ├── logger.py
│   ├── readpfm.py
│   ├── merge_img2video.py
│   ├── preprocess.py
│   └── flops_hook.py
├── dataloader
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── readpfm.cpython-36.pyc
│   │   ├── __init__.cpython-36.pyc
│   │   ├── KITTILoader.cpython-36.pyc
│   │   ├── listflowfile.cpython-36.pyc
│   │   ├── preprocess.cpython-36.pyc
│   │   ├── KITTILoader_mask.cpython-36.pyc
│   │   ├── KITTILoader_video.cpython-36.pyc
│   │   ├── KITTI_0028_sync.cpython-36.pyc
│   │   ├── KITTIloader2012.cpython-36.pyc
│   │   ├── KITTIloader2015.cpython-36.pyc
│   │   ├── SecenFlowLoader.cpython-36.pyc
│   │   ├── preprocess_change.cpython-36.pyc
│   │   ├── KITTILoader_change.cpython-36.pyc
│   │   ├── KITTILoader_0028_0071.cpython-36.pyc
│   │   ├── KITTILoader_One_cycle.cpython-36.pyc
│   │   ├── KITTILoader_supervised.cpython-36.pyc
│   │   ├── KITTIloader2015_mask.cpython-36.pyc
│   │   ├── KITTIloader2015_test.cpython-36.pyc
│   │   ├── KITTIloader2015_video.cpython-36.pyc
│   │   ├── KITTIloader_0028_sync.cpython-36.pyc
│   │   ├── KITTILoader_submit_to_2015.cpython-36.pyc
│   │   ├── KITTIloader2015_One_cycle.cpython-36.pyc
│   │   ├── KITTIloader2015_supervised.cpython-36.pyc
│   │   ├── KITTIloader_list_0028_0071.cpython-36.pyc
│   │   ├── preprocess_submit_to_2015.cpython-36.pyc
│   │   └── KITTIloader2015_submit_to_2015.cpython-36.pyc
│   ├── KITTI_submission_loader.py
│   ├── readpfm.py
│   ├── KITTILoader_One_cycle.py
│   ├── KITTILoader_0028_0071.py
│   ├── KITTIloader2015_One_cycle.py
│   ├── SecenFlowLoader.py
│   ├── KITTILoader.py
│   ├── KITTIdatalist.py
│   ├── listflowfile.py
│   └── preprocess.py
├── models
│   ├── __pycache__
│   │   ├── comm.cpython-36.pyc
│   │   ├── cspn.cpython-36.pyc
│   │   ├── LWADNet.cpython-36.pyc
│   │   ├── cspn_5_12.cpython-36.pyc
│   │   ├── LWADNet_5_12.cpython-36.pyc
│   │   ├── LWADNet_flops.cpython-36.pyc
│   │   ├── LWADNet_submodules.cpython-36.pyc
│   │   ├── LWADNet_submodules_BN.cpython-36.pyc
│   │   ├── batch_normalization.cpython-36.pyc
│   │   ├── LWADNet_submodules_test.cpython-36.pyc
│   │   └── LWADNet_submodules_IN_FRN.cpython-36.pyc
│   ├── feature_extraction.py
│   ├── loss.py
│   ├── Aggregation_submodules.py
│   ├── LWANet.py
│   ├── cost.py
│   └── cspn.py
├── README.md
├── env.yaml
├── submission.py
├── main.py
├── Online_adaptation.py
├── finetune.py
└── One_cycle.py

/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/dataloader/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | 
4 | 
5 | def setup_logger(filepath):
6 |     file_formatter = logging.Formatter(
7 |         "[%(asctime)s %(filename)s:%(lineno)s] %(levelname)-8s %(message)s",
8 |         datefmt='%Y-%m-%d %H:%M:%S',
9 |     )
10 |     logger = logging.getLogger('example')
11 |     handler = logging.StreamHandler()
12 |     handler.setFormatter(file_formatter)
13 |     logger.addHandler(handler)
14 | 
15 |     file_handle_name = "file"
16 |     if file_handle_name in [h.name for h in logger.handlers]:
17 |         return logger
18 |     if os.path.dirname(filepath) != '':
19 |         if not os.path.isdir(os.path.dirname(filepath)):
20 |             os.makedirs(os.path.dirname(filepath))
21 |     file_handle = logging.FileHandler(filename=filepath, mode="a")
22 |     file_handle.set_name(file_handle_name)
23 |     file_handle.setFormatter(file_formatter)
24 |     logger.addHandler(file_handle)
25 |     logger.setLevel(logging.DEBUG)
26 |     return logger
--------------------------------------------------------------------------------
/dataloader/KITTI_submission_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 | 
4 | IMG_EXTENSIONS = [
5 |     '.jpg', '.JPG', '.jpeg', '.JPEG',
6 |     '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
7 | ]
8 | 
9 | 
10 | def is_image_file(filename):
11 |     return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
12 | 
13 | 
14 | def dataloader2015(filepath):
15 | 
16 |     left_fold = 'image_2/'
17 |     right_fold = 'image_3/'
18 | 
19 | 
20 |     image = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1]
21 | 
22 | 
23 |     left_test = [filepath+left_fold+img for img in image]
24 |     right_test = [filepath+right_fold+img for img in image]
25 | 
26 |     return left_test, right_test
27 | 
28 | 
29 | def dataloader2012(filepath):
30 | 
31 |     left_fold = 'colored_0/'
32 |     right_fold = 'colored_1/'
33 | 
34 | 
35 |     image = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1]
36 | 
37 | 
38 |     left_test = [filepath+left_fold+img for img in image]
39 |     right_test = [filepath+right_fold+img for img in image]
40 | 
41 |     return left_test, right_test
42 | 
--------------------------------------------------------------------------------
/utils/readpfm.py:
--------------------------------------------------------------------------------
1 | import re
2 | import numpy as np
3 | import sys
4 | 
5 | 
6 | def readPFM(file):
7 |     file = open(file, 'rb')
8 | 
9 |     color = None
10 |     width = None
11 |     height = None
12 |     scale = None
13 |     endian = None
14 | 
15 |     header = file.readline().rstrip()
16 |     if header == b'PF':
17 |         color = True
18 |     elif header == b'Pf':
19 |         color = False
20 |     else:
21 |         raise Exception('Not a PFM file.')
22 | 
23 |     dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8'))
24 |     if dim_match:
25 |         width, height = map(int, dim_match.groups())
26 |     else:
27 |         raise Exception('Malformed PFM header.')
28 | 
29 |     scale = float(file.readline().rstrip())
30 |     if scale < 0:  # little-endian
31 |         endian = '<'
32 |         scale = -scale
33 |     else:
34 |         endian = '>'  # big-endian
35 | 
36 |     data = np.fromfile(file, endian + 'f')
37 |     shape = (height, width, 3) if color else (height, width)
38 | 
39 |     data = np.reshape(data, shape)
40 |     data = np.flipud(data)
41 |     return data, scale
42 | 
--------------------------------------------------------------------------------
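readPFM above (and the dataloader copy that follows) parses the SceneFlow-style PFM disparity format: the 'PF'/'Pf' magic, a width/height line, and a scale/endianness line, followed by the raw floats, which are flipped vertically before being returned. A minimal usage sketch (the .pfm path below is hypothetical):

```python
# Hypothetical example of reading a SceneFlow disparity map with readPFM.
from dataloader import readpfm as rp

disp, scale = rp.readPFM('disparity/TRAIN/A/0000/left/0006.pfm')
print(disp.shape, disp.dtype, scale)  # e.g. (540, 960) float32 1.0
```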
/dataloader/readpfm.py:
--------------------------------------------------------------------------------
1 | import re
2 | import numpy as np
3 | import sys
4 | 
5 | 
6 | def readPFM(file):
7 |     file = open(file, 'rb')
8 | 
9 |     color = None
10 |     width = None
11 |     height = None
12 |     scale = None
13 |     endian = None
14 | 
15 |     header = file.readline().rstrip()
16 |     if header == b'PF':
17 |         color = True
18 |     elif header == b'Pf':
19 |         color = False
20 |     else:
21 |         raise Exception('Not a PFM file.')
22 | 
23 |     dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8'))
24 |     if dim_match:
25 |         width, height = map(int, dim_match.groups())
26 |     else:
27 |         raise Exception('Malformed PFM header.')
28 | 
29 |     scale = float(file.readline().rstrip())
30 |     if scale < 0:  # little-endian
31 |         endian = '<'
32 |         scale = -scale
33 |     else:
34 |         endian = '>'  # big-endian
35 | 
36 |     data = np.fromfile(file, endian + 'f')
37 |     shape = (height, width, 3) if color else (height, width)
38 | 
39 |     data = np.reshape(data, shape)
40 |     data = np.flipud(data)
41 |     file.close()
42 |     return data, scale
43 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # LWANet
2 | This repository contains the code for our paper: [Light-weight Network for Real-time Adaptive Stereo Depth Estimation](https://www.sciencedirect.com/science/article/pii/S0925231221002599)
3 | 
4 | # Abstract
5 | Self-supervised learning methods have proven effective for real-time stereo
6 | depth estimation while requiring less memory and computation. In this
7 | paper, a light-weight adaptive network (LWANet) is proposed that combines self-supervised
8 | learning to perform online adaptive stereo depth estimation at low computational cost and
9 | low GPU memory usage. Instead of a regular 3D convolution, a pseudo 3D convolution is
10 | employed in the proposed light-weight network to aggregate the cost volume, achieving a better
11 | balance between accuracy and computational cost. Moreover, based on a U-Net architecture,
12 | the downsampling feature extractor is combined with a refined convolutional spatial propagation
13 | network (CSPN) to further improve the estimation accuracy with little extra memory or
14 | computational cost. Extensive experiments demonstrate that the proposed LWANet effectively
15 | alleviates the domain shift problem by updating the neural network online, which makes it suitable for
16 | embedded devices such as the NVIDIA Jetson TX2.
17 | 
18 | # Usage
19 | 
20 | To be updated
21 | 
22 | # Citation
23 | 
24 | If you find this work useful, you are welcome to cite it as:
25 | 
26 | ```
27 | Gan, W., Wong, P. K., Yu, G., Zhao, R., & Vong, C. M. (2021). Light-weight Network for Real-time Adaptive Stereo Depth Estimation. Neurocomputing.
28 | ```
29 | 
30 | # Acknowledgement
31 | 
32 | Many thanks to the authors of [AnyNet](https://github.com/mileyan/AnyNet) and [CSPN](https://github.com/XinJCheng/CSPN) for open-sourcing their code.
--------------------------------------------------------------------------------
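The abstract above replaces regular 3D convolutions with pseudo 3D convolutions when aggregating the cost volume. The sketch below only illustrates that factorization (a 3x3x3 filter split into a 1x3x3 spatial convolution followed by a 3x1x1 convolution along the disparity dimension); the class name PseudoConv3d and the channel/kernel choices are assumptions for illustration and do not reproduce the actual layers in models/Aggregation_submodules.py.

```python
# Minimal sketch of a pseudo 3D (separated) convolution over a cost volume.
# Illustrative only; not the layers used in models/Aggregation_submodules.py.
import torch
import torch.nn as nn


class PseudoConv3d(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(PseudoConv3d, self).__init__()
        # spatial filtering: kernel (1, 3, 3) over the (D, H, W) volume
        self.spatial = nn.Conv3d(in_ch, out_ch, kernel_size=(1, 3, 3),
                                 padding=(0, 1, 1), bias=False)
        # disparity filtering: kernel (3, 1, 1) along the disparity axis
        self.disparity = nn.Conv3d(out_ch, out_ch, kernel_size=(3, 1, 1),
                                   padding=(1, 0, 0), bias=False)
        self.bn = nn.BatchNorm3d(out_ch)
        self.act = nn.ELU(inplace=True)

    def forward(self, cost):
        # cost: (N, C, D, H, W) cost volume
        return self.act(self.bn(self.disparity(self.spatial(cost))))


if __name__ == "__main__":
    cost = torch.randn(1, 8, 24, 80, 152)    # toy cost volume
    print(PseudoConv3d(8, 16)(cost).shape)   # torch.Size([1, 16, 24, 80, 152])
```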
/dataloader/KITTILoader_One_cycle.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | from PIL import Image, ImageOps
3 | import numpy as np
4 | from . import preprocess
5 | 
6 | IMG_EXTENSIONS = [
7 |     '.jpg', '.JPG', '.jpeg', '.JPEG',
8 |     '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
9 | ]
10 | 
11 | def is_image_file(filename):
12 |     return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
13 | 
14 | def default_loader(path):
15 |     return Image.open(path).convert('RGB')
16 | 
17 | def disparity_loader(path):
18 |     return Image.open(path)
19 | 
20 | 
21 | class myImageFloder(data.Dataset):
22 |     def __init__(self, left, right, left_disparity, training, loader=default_loader, dploader= disparity_loader):
23 | 
24 |         self.left = left
25 |         self.right = right
26 |         self.disp_L = left_disparity
27 |         self.loader = loader
28 |         self.dploader = dploader
29 |         self.training = training
30 | 
31 |     def __getitem__(self, index):
32 |         left = self.left[index]
33 |         right = self.right[index]
34 |         disp_L = self.disp_L[index]
35 | 
36 |         left_img = self.loader(left)
37 |         right_img = self.loader(right)
38 |         dataL = self.dploader(disp_L)
39 | 
40 | 
41 | 
42 |         w, h = left_img.size
43 | 
44 |         left_img = left_img.crop((w - 1216, h - 320, w, h))
45 |         right_img = right_img.crop((w - 1216, h - 320, w, h))
46 | 
47 | 
48 | 
49 |         dataL = dataL.crop((w - 1216, h - 320, w, h))
50 | 
51 |         dataL = np.ascontiguousarray(dataL, dtype=np.float32) / 256
52 | 
53 |         processed = preprocess.get_transform(augment=False)
54 |         left_img = processed(left_img)
55 |         right_img = processed(right_img)
56 | 
57 | 
58 | 
59 |         return left_img, right_img, dataL
60 | 
61 | 
62 | 
63 | 
64 | 
65 | 
66 |     def __len__(self):
67 |         return len(self.left)
68 | 
69 | 
70 | 
71 | 
--------------------------------------------------------------------------------
/dataloader/KITTILoader_0028_0071.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | from PIL import Image, ImageOps
3 | import numpy as np
4 | from .
import preprocess 5 | 6 | IMG_EXTENSIONS = [ 7 | '.jpg', '.JPG', '.jpeg', '.JPEG', 8 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 9 | ] 10 | 11 | def is_image_file(filename): 12 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 13 | 14 | def default_loader(path): 15 | return Image.open(path).convert('RGB') 16 | 17 | def disparity_loader(path): 18 | return Image.open(path) 19 | 20 | 21 | class myImageFloder(data.Dataset): 22 | def __init__(self, left, right, left_disparity, training, loader=default_loader, dploader= disparity_loader): 23 | 24 | self.left = left 25 | self.right = right 26 | self.disp_L = left_disparity 27 | self.loader = loader 28 | self.dploader = dploader 29 | self.training = training 30 | 31 | def __getitem__(self, index): 32 | left = self.left[index] 33 | right = self.right[index] 34 | disp_L= self.disp_L[index] 35 | 36 | left_img = self.loader(left) 37 | right_img = self.loader(right) 38 | dataL = self.dploader(disp_L) 39 | 40 | 41 | # full image 42 | w, h = left_img.size 43 | left_img = left_img.crop((w - 1216, h - 320, w, h)) 44 | right_img = right_img.crop((w - 1216, h - 320, w, h)) 45 | dataL = dataL.crop((w - 1216, h - 320, w, h)) 46 | dataL = np.ascontiguousarray(dataL, dtype=np.float32) / 256 47 | 48 | # 0028 49 | #dataL = 0.54 * 707 / dataL 50 | 51 | # 0071 52 | dataL = 0.54 * 718 / dataL 53 | 54 | processed = preprocess.get_transform(augment=False) 55 | left_img = processed(left_img) 56 | right_img = processed(right_img) 57 | 58 | 59 | return left_img, right_img, dataL 60 | 61 | 62 | def __len__(self): 63 | return len(self.left) 64 | 65 | 66 | 67 | -------------------------------------------------------------------------------- /dataloader/KITTIloader2015_One_cycle.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | from PIL import Image 4 | import os 5 | import os.path 6 | import numpy as np 7 | 8 | IMG_EXTENSIONS = [ 9 | '.jpg', '.JPG', '.jpeg', '.JPEG', 10 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 11 | ] 12 | 13 | 14 | def is_image_file(filename): 15 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 16 | 17 | def dataloader(filepath, log): 18 | 19 | left_fold = 'image_2/' 20 | right_fold = 'image_3/' 21 | disp_L = 'disp_occ_0/' 22 | #disp_R = 'disp_occ_1/' 23 | # 24 | # left_fold = 'image_02/data/' 25 | # right_fold = 'image_03/data/' 26 | # disp_L = 'data_depth_annotated/2011_09_30_drive_0028_sync/proj_depth/groundtruth/image_02/' 27 | #disp_R = 'disp_occ_1/' 28 | 29 | image = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1] 30 | #image = [img for img in os.listdir(filepath + left_fold) if img.find('000000_10')] 31 | #print('image 0:', len(image)) 32 | 33 | all_index = np.arange(200) 34 | #np.random.seed(2) 35 | #np.random.shuffle(all_index) 36 | #print('all_index:', all_index) 37 | vallist = all_index[:40] 38 | 39 | log.info(vallist) 40 | val = ['{:06d}_10.png'.format(x) for x in vallist] 41 | #train = [x for x in image if x not in val] 42 | train = [x for x in image if x == '000128_10.png'] 43 | print('train :', train[0]) 44 | 45 | 46 | 47 | 48 | 49 | left_train = [filepath+left_fold+img for img in train] 50 | right_train = [filepath+right_fold+img for img in train] 51 | disp_train_L = [filepath+disp_L+img for img in train] 52 | #disp_train_R = [filepath+disp_R+img for img in train] 53 | 54 | left_val = [filepath+left_fold+img for img in val] 55 | right_val = [filepath+right_fold+img for img in 
val] 56 | disp_val_L = [filepath+disp_L+img for img in val] 57 | #disp_val_R = [filepath+disp_R+img for img in val] 58 | 59 | 60 | return left_train, right_train, disp_train_L, left_val, right_val, disp_val_L 61 | #return left_train, right_train, disp_train_L, disp_train_R, left_val, right_val, disp_val_L, disp_val_R -------------------------------------------------------------------------------- /dataloader/SecenFlowLoader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.utils.data as data 4 | import torch 5 | import torchvision.transforms as transforms 6 | import random 7 | from PIL import Image, ImageOps 8 | from . import preprocess 9 | from . import listflowfile as lt 10 | from . import readpfm as rp 11 | import numpy as np 12 | 13 | IMG_EXTENSIONS = [ 14 | '.jpg', '.JPG', '.jpeg', '.JPEG', 15 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 16 | ] 17 | 18 | 19 | def is_image_file(filename): 20 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 21 | 22 | 23 | def default_loader(path): 24 | return Image.open(path).convert('RGB') 25 | 26 | 27 | def disparity_loader(path): 28 | return rp.readPFM(path) 29 | 30 | 31 | class myImageFloder(data.Dataset): 32 | def __init__(self, left, right, left_disparity, training, loader=default_loader, dploader=disparity_loader): 33 | 34 | self.left = left 35 | self.right = right 36 | self.disp_L = left_disparity 37 | self.loader = loader 38 | self.dploader = dploader 39 | self.training = training 40 | 41 | def __getitem__(self, index): 42 | left = self.left[index] 43 | right = self.right[index] 44 | disp_L = self.disp_L[index] 45 | 46 | left_img = self.loader(left) 47 | right_img = self.loader(right) 48 | dataL, scaleL = self.dploader(disp_L) 49 | dataL = np.ascontiguousarray(dataL, dtype=np.float32) 50 | 51 | if self.training: 52 | w, h = left_img.size 53 | #th, tw = 256, 512 54 | th, tw = 512, 960 55 | 56 | x1 = random.randint(0, w - tw) 57 | y1 = random.randint(0, h - th) 58 | 59 | left_img = left_img.crop((x1, y1, x1 + tw, y1 + th)) 60 | right_img = right_img.crop((x1, y1, x1 + tw, y1 + th)) 61 | 62 | dataL = dataL[y1:y1 + th, x1:x1 + tw] 63 | 64 | processed = preprocess.get_transform(augment=False) 65 | left_img = processed(left_img) 66 | right_img = processed(right_img) 67 | 68 | return left_img, right_img, dataL 69 | else: 70 | w, h = left_img.size 71 | left_img = left_img.crop((w - 960, h - 544, w, h)) 72 | right_img = right_img.crop((w - 960, h - 544, w, h)) 73 | processed = preprocess.get_transform(augment=False) 74 | left_img = processed(left_img) 75 | right_img = processed(right_img) 76 | 77 | return left_img, right_img, dataL 78 | 79 | def __len__(self): 80 | return len(self.left) 81 | -------------------------------------------------------------------------------- /dataloader/KITTILoader.py: -------------------------------------------------------------------------------- 1 | 2 | import torch.utils.data as data 3 | import random 4 | from PIL import Image, ImageOps 5 | import numpy as np 6 | from . 
import preprocess 7 | 8 | IMG_EXTENSIONS = [ 9 | '.jpg', '.JPG', '.jpeg', '.JPEG', 10 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 11 | ] 12 | 13 | def is_image_file(filename): 14 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 15 | 16 | def default_loader(path): 17 | return Image.open(path).convert('RGB') 18 | 19 | def disparity_loader(path): 20 | return Image.open(path) 21 | 22 | 23 | class myImageFloder(data.Dataset): 24 | def __init__(self, left, right, left_disparity, training, loader=default_loader, dploader= disparity_loader): 25 | 26 | self.left = left 27 | self.right = right 28 | self.disp_L = left_disparity 29 | self.loader = loader 30 | self.dploader = dploader 31 | self.training = training 32 | 33 | def __getitem__(self, index): 34 | left = self.left[index] 35 | right = self.right[index] 36 | disp_L= self.disp_L[index] 37 | 38 | left_img = self.loader(left) 39 | right_img = self.loader(right) 40 | dataL = self.dploader(disp_L) 41 | 42 | 43 | if self.training: 44 | w, h = left_img.size 45 | #print(' w, h:', w, h) 46 | #th, tw = 256, 512 47 | th, tw = 288, 624 48 | 49 | 50 | x1 = random.randint(0, w - tw) 51 | 52 | y1 = random.randint(0, h - th) 53 | 54 | left_img = left_img.crop((x1, y1, x1 + tw, y1 + th)) 55 | right_img = right_img.crop((x1, y1, x1 + tw, y1 + th)) 56 | 57 | 58 | 59 | dataL = np.ascontiguousarray(dataL,dtype=np.float32)/256 60 | dataL = dataL[y1:y1 + th, x1:x1 + tw] 61 | 62 | processed = preprocess.get_transform(augment=False) 63 | left_img = processed(left_img) 64 | right_img = processed(right_img) 65 | 66 | return left_img, right_img, dataL 67 | 68 | 69 | 70 | 71 | else: 72 | w, h = left_img.size 73 | 74 | left_img = left_img.crop((w-1232, h-368, w, h)) 75 | right_img = right_img.crop((w-1232, h-368, w, h)) 76 | 77 | dataL = dataL.crop((w-1232, h-368, w, h)) 78 | 79 | dataL = np.ascontiguousarray(dataL,dtype=np.float32)/256 80 | 81 | processed = preprocess.get_transform(augment=False) 82 | left_img = processed(left_img) 83 | right_img = processed(right_img) 84 | 85 | 86 | 87 | return left_img, right_img, dataL 88 | 89 | 90 | def __len__(self): 91 | return len(self.left) 92 | 93 | -------------------------------------------------------------------------------- /.idea/workspace.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 10 | 11 | 16 | 17 | 18 | 19 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 1610205822897 39 | 45 | 46 | 47 | 48 | 50 | -------------------------------------------------------------------------------- /dataloader/KITTIdatalist.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | from PIL import Image 4 | import os 5 | import os.path 6 | import numpy as np 7 | import random 8 | 9 | IMG_EXTENSIONS = [ 10 | '.jpg', '.JPG', '.jpeg', '.JPEG', 11 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 12 | ] 13 | 14 | 15 | def is_image_file(filename): 16 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 17 | 18 | def dataloader2012(filepath, log, split=False): 19 | 20 | left_fold = 'colored_0/' 21 | right_fold = 'colored_1/' 22 | disp_noc = 'disp_occ/' 23 | 24 | image = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1] 25 | random.shuffle(image) 26 | 27 | 28 | if not split: 29 | 30 | np.random.seed(2) 31 | random.shuffle(image) 32 | train = image[:] 33 | val = image[160:] 34 | 35 | else: 36 | 37 | train = 
image[:160] 38 | val = image[160:] 39 | 40 | 41 | 42 | log.info(val) 43 | 44 | left_train = [filepath+left_fold+img for img in train] 45 | right_train = [filepath+right_fold+img for img in train] 46 | disp_train = [filepath+disp_noc+img for img in train] 47 | 48 | 49 | left_val = [filepath+left_fold+img for img in val] 50 | right_val = [filepath+right_fold+img for img in val] 51 | disp_val = [filepath+disp_noc+img for img in val] 52 | 53 | return left_train, right_train, disp_train, left_val, right_val, disp_val 54 | 55 | 56 | 57 | def dataloader2015(filepath, log, split = False): 58 | 59 | left_fold = 'image_2/' 60 | right_fold = 'image_3/' 61 | disp_L = 'disp_occ_0/' 62 | 63 | image = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1] 64 | 65 | all_index = np.arange(200) 66 | np.random.seed(2) 67 | np.random.shuffle(all_index) 68 | #print('all_index:', all_index) 69 | vallist = all_index[:40] 70 | 71 | log.info(vallist) 72 | val = ['{:06d}_10.png'.format(x) for x in vallist] 73 | 74 | if split: 75 | train = [x for x in image if x not in val] 76 | # train = [x for x in image if x not in val] 77 | 78 | else: 79 | train = [x for x in image] 80 | 81 | 82 | 83 | left_train = [filepath+left_fold+img for img in train] 84 | right_train = [filepath+right_fold+img for img in train] 85 | disp_train_L = [filepath+disp_L+img for img in train] 86 | #disp_train_R = [filepath+disp_R+img for img in train] 87 | 88 | left_val = [filepath+left_fold+img for img in val] 89 | right_val = [filepath+right_fold+img for img in val] 90 | disp_val_L = [filepath+disp_L+img for img in val] 91 | #disp_val_R = [filepath+disp_R+img for img in val] 92 | 93 | 94 | return left_train, right_train, disp_train_L, left_val, right_val, disp_val_L 95 | 96 | 97 | 98 | 99 | def dataloader_adaptation(filepath, datatype): 100 | 101 | # 0028 102 | left_fold = 'raw_image/image_02/data/' 103 | right_fold = 'raw_image/image_03/data/' # w, h: 1226 370 104 | disp_L = 'disparity/image_02/' 105 | 106 | path_list = os.listdir(filepath + left_fold) 107 | path_list.sort(key=lambda x: int(x.split('.')[0])) 108 | image = [img for img in path_list] 109 | 110 | 111 | #0028 112 | if datatype == "0028": 113 | image = image[5:2005] 114 | 115 | 116 | elif datatype == "0071": 117 | # 0071 118 | image = image[5:-6] 119 | 120 | 121 | train = [x for x in image] 122 | 123 | left_train = [filepath+left_fold+img for img in train] 124 | right_train = [filepath+right_fold+img for img in train] 125 | disp_train_L = [filepath+disp_L+img for img in train] 126 | 127 | return left_train, right_train, -------------------------------------------------------------------------------- /models/feature_extraction.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | from __future__ import print_function 3 | import torch.nn as nn 4 | 5 | 6 | class F1(nn.Module): 7 | def __init__(self): 8 | super(F1, self).__init__() 9 | # feature extraction 10 | self.init_feature = nn.Sequential( 11 | 12 | # 6-24 13 | nn.Conv2d(3, 4, 3, 1, 1, bias=False), 14 | nn.BatchNorm2d(4), 15 | nn.ELU(inplace=True), 16 | nn.Conv2d(4, 4, 3, 2, 1, bias=False), 17 | nn.Conv2d(4, 8, 3, 1, 1, bias=False), 18 | 19 | ) 20 | 21 | def forward(self, x_left): 22 | 23 | buffer_left = self.init_feature(x_left) 24 | 25 | return buffer_left 26 | 27 | 28 | 29 | class F2(nn.Module): 30 | def __init__(self): 31 | super(F2, self).__init__() 32 | 33 | self.init_feature = nn.Sequential( 34 | 35 | 36 | nn.MaxPool2d(2, 2), 37 | nn.BatchNorm2d(8), 38 | 
nn.ELU(inplace=True), 39 | 40 | nn.Conv2d(8, 12, 3, 1, 1, bias=False), 41 | nn.BatchNorm2d(12), 42 | nn.ELU(inplace=True), 43 | nn.Conv2d(12, 12, 3, 1, 1, bias=False), 44 | 45 | ) 46 | 47 | def forward(self, x_left): 48 | 49 | buffer_left = self.init_feature(x_left) 50 | 51 | return buffer_left 52 | 53 | 54 | class F3(nn.Module): 55 | def __init__(self): 56 | super(F3, self).__init__() 57 | 58 | self.init_feature = nn.Sequential( 59 | 60 | nn.MaxPool2d(2, 2), 61 | nn.BatchNorm2d(12), 62 | nn.ELU(inplace=True), 63 | 64 | nn.Conv2d(12, 16, 3, 1, 1, bias=False), 65 | nn.BatchNorm2d(16), 66 | nn.ELU(inplace=True), 67 | nn.Conv2d(16, 16, 3, 1, 1, bias=False), 68 | 69 | ) 70 | 71 | def forward(self, x_left): 72 | 73 | buffer_left = self.init_feature(x_left) 74 | 75 | return buffer_left 76 | 77 | 78 | 79 | 80 | class F3_UP(nn.Module): 81 | def __init__(self): 82 | super(F3_UP, self).__init__() 83 | self.init_feature = nn.Sequential( 84 | 85 | nn.Conv2d(16, 16, 3, 1, 1, bias=False), 86 | 87 | nn.BatchNorm2d(16), 88 | 89 | nn.ELU(inplace=True), 90 | 91 | nn.ConvTranspose2d(16, 12, 3, 2, 1, output_padding=1, bias=False), 92 | ) 93 | 94 | def forward(self, x_left): 95 | 96 | buffer_left = self.init_feature(x_left) 97 | 98 | return buffer_left 99 | 100 | 101 | class F2_UP(nn.Module): 102 | def __init__(self): 103 | super(F2_UP, self).__init__() 104 | 105 | # cat 106 | self.init_feature = nn.Sequential( 107 | 108 | nn.BatchNorm2d(24), 109 | nn.ELU(inplace=True), 110 | nn.ConvTranspose2d(24, 8, 3, 2, 1, output_padding=1, bias=False), 111 | ) 112 | 113 | 114 | 115 | def forward(self, x_left): 116 | ### feature extraction 117 | buffer_left = self.init_feature(x_left) 118 | 119 | return buffer_left 120 | 121 | 122 | class F1_UP(nn.Module): 123 | def __init__(self): 124 | super(F1_UP, self).__init__() 125 | # cat 126 | self.init_feature = nn.Sequential( 127 | 128 | nn.BatchNorm2d(16), 129 | nn.ELU(inplace=True), 130 | nn.ConvTranspose2d(16, 8, 3, 2, 1, output_padding=1, bias=False), 131 | 132 | ) 133 | 134 | 135 | def forward(self, x_left): 136 | ### feature extraction 137 | buffer_left = self.init_feature(x_left) 138 | 139 | return buffer_left 140 | 141 | 142 | 143 | -------------------------------------------------------------------------------- /env.yaml: -------------------------------------------------------------------------------- 1 | name: LWANet 2 | channels: 3 | - pytorch 4 | - anaconda 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - blas=1.0=mkl 9 | - bzip2=1.0.8=h7b6447c_0 10 | - ca-certificates=2020.6.24=0 11 | - cairo=1.14.12=h8948797_3 12 | - certifi=2020.6.20=py37_0 13 | - cffi=1.12.3=py37h2e261b9_0 14 | - cloudpickle=1.2.1=py_0 15 | - cuda80=1.0=h205658b_0 16 | - cudatoolkit=10.0.130=0 17 | - cycler=0.10.0=py37_0 18 | - cytoolz=0.10.0=py37h7b6447c_0 19 | - dask-core=2.3.0=py_0 20 | - dbus=1.13.6=h746ee38_0 21 | - decorator=4.4.0=py37_1 22 | - expat=2.2.6=he6710b0_0 23 | - ffmpeg=4.0=hcdf2ecd_0 24 | - fontconfig=2.13.0=h9420a91_0 25 | - freeglut=3.0.0=hf484d3e_5 26 | - freetype=2.9.1=h8a8886c_1 27 | - glib=2.56.2=hd408876_0 28 | - graphite2=1.3.13=h23475e2_0 29 | - gst-plugins-base=1.14.0=hbbd80ab_1 30 | - gstreamer=1.14.0=hb453b48_1 31 | - harfbuzz=1.8.8=hffaf4a1_0 32 | - hdf5=1.10.2=hba1933b_1 33 | - icu=58.2=h9c2bf20_1 34 | - imageio=2.5.0=py37_0 35 | - intel-openmp=2019.4=243 36 | - jasper=2.0.14=h07fcdf6_1 37 | - jpeg=9b=h024ee3a_2 38 | - kiwisolver=1.1.0=py37he6710b0_0 39 | - libedit=3.1.20181209=hc058e9b_0 40 | - libffi=3.2.1=hd88cf55_4 41 | - 
libgcc-ng=9.1.0=hdf63c60_0 42 | - libgfortran-ng=7.3.0=hdf63c60_0 43 | - libglu=9.0.0=hf484d3e_1 44 | - libopencv=3.4.2=hb342d67_1 45 | - libopus=1.3=h7b6447c_0 46 | - libpng=1.6.37=hbc83047_0 47 | - libstdcxx-ng=9.1.0=hdf63c60_0 48 | - libtiff=4.0.10=h2733197_2 49 | - libuuid=1.0.3=h1bed415_2 50 | - libvpx=1.7.0=h439df22_0 51 | - libxcb=1.13=h1bed415_1 52 | - libxml2=2.9.9=hea5a465_1 53 | - matplotlib=3.1.1=py37h5429711_0 54 | - mkl=2019.4=243 55 | - mkl-service=2.0.2=py37h7b6447c_0 56 | - mkl_fft=1.0.14=py37ha843d7b_0 57 | - mkl_random=1.0.2=py37hd81dba3_0 58 | - ncurses=6.1=he6710b0_1 59 | - networkx=2.3=py_0 60 | - ninja=1.9.0=py37hfd86e86_0 61 | - olefile=0.46=py37_0 62 | - opencv=3.4.2=py37h6fd60c2_1 63 | - openssl=1.1.1g=h7b6447c_0 64 | - pcre=8.43=he6710b0_0 65 | - pillow=6.1.0=py37h34e0f95_0 66 | - pip=19.2.2=py37_0 67 | - pixman=0.38.0=h7b6447c_0 68 | - py-opencv=3.4.2=py37hb342d67_1 69 | - pycparser=2.19=py37_0 70 | - pyparsing=2.4.2=py_0 71 | - pyqt=5.9.2=py37h05f1152_2 72 | - python=3.7.4=h265db76_1 73 | - python-dateutil=2.8.0=py37_0 74 | - pytorch=1.0.0=py3.7_cuda8.0.61_cudnn7.1.2_1 75 | - pytz=2019.2=py_0 76 | - pywavelets=1.0.3=py37hdd07704_1 77 | - qt=5.9.7=h5867ecd_1 78 | - readline=7.0=h7b6447c_5 79 | - scikit-image=0.15.0=py37he6710b0_0 80 | - scipy=1.3.1=py37h7c811a0_0 81 | - setuptools=41.0.1=py37_0 82 | - sip=4.19.8=py37hf484d3e_0 83 | - six=1.12.0=py37_0 84 | - sqlite=3.29.0=h7b6447c_0 85 | - tk=8.6.8=hbc83047_0 86 | - toolz=0.10.0=py_0 87 | - torchvision=0.2.1=py_2 88 | - tornado=6.0.3=py37h7b6447c_0 89 | - wheel=0.33.4=py37_0 90 | - xz=5.2.4=h14c3975_4 91 | - zlib=1.2.11=h7b6447c_3 92 | - zstd=1.3.7=h0b5b093_0 93 | - pip: 94 | - absl-py==0.8.0 95 | - apex==0.1 96 | - astor==0.8.0 97 | - bleach==1.5.0 98 | - chardet==3.0.4 99 | - future==0.17.1 100 | - gast==0.2.2 101 | - google-pasta==0.1.7 102 | - grpcio==1.23.0 103 | - h5py==2.9.0 104 | - html5lib==0.9999999 105 | - idna==2.10 106 | - imageio-ffmpeg==0.4.2 107 | - keras-applications==1.0.8 108 | - keras-preprocessing==1.1.0 109 | - markdown==3.1.1 110 | - moviepy==1.0.3 111 | - numpy==1.19.4 112 | - proglog==0.1.9 113 | - protobuf==3.9.1 114 | - requests==2.25.1 115 | - tb-nightly==1.15.0a20190902 116 | - termcolor==1.1.0 117 | - thop==0.0.31-2005241907 118 | - tqdm==4.54.1 119 | - urllib3==1.26.2 120 | - werkzeug==0.15.5 121 | - wrapt==1.11.2 122 | - yum==0.0.1 123 | prefix: /home/wsgan/anaconda3/envs/aanet 124 | -------------------------------------------------------------------------------- /dataloader/listflowfile.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | from PIL import Image 4 | import os 5 | import os.path 6 | 7 | IMG_EXTENSIONS = [ 8 | '.jpg', '.JPG', '.jpeg', '.JPEG', 9 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 10 | ] 11 | 12 | 13 | def is_image_file(filename): 14 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 15 | 16 | def dataloader(filepath): 17 | 18 | classes = [d for d in os.listdir(filepath) if os.path.isdir(os.path.join(filepath, d))] 19 | image = [img for img in classes if img.find('frames_cleanpass') > -1] 20 | disp = [dsp for dsp in classes if dsp.find('disparity') > -1] 21 | 22 | print('len image:',len(image)) 23 | 24 | monkaa_path = filepath + '' + [x for x in image if 'monkaa' in x][0] 25 | monkaa_disp = filepath + [x for x in disp if 'monkaa' in x][0] 26 | # monkaa_path = filepath + 'monkaa_frames_cleanpass' 27 | # monkaa_disp = filepath + 'monkaa_disparity' 28 | 29 | 
monkaa_dir = os.listdir(monkaa_path) 30 | 31 | 32 | all_left_img=[] 33 | all_right_img=[] 34 | all_left_disp = [] 35 | test_left_img=[] 36 | test_right_img=[] 37 | test_left_disp = [] 38 | 39 | 40 | for dd in monkaa_dir: 41 | 42 | for im in os.listdir(monkaa_path+'/'+dd+'/left/'): 43 | if is_image_file(monkaa_path+'/'+dd+'/left/'+im): 44 | all_left_img.append(monkaa_path+'/'+dd+'/left/'+im) 45 | all_left_disp.append(monkaa_disp+'/'+dd+'/left/'+im.split(".")[0]+'.pfm') 46 | 47 | for im in os.listdir(monkaa_path+'/'+dd+'/right/'): 48 | if is_image_file(monkaa_path+'/'+dd+'/right/'+im): 49 | all_right_img.append(monkaa_path+'/'+dd+'/right/'+im) 50 | 51 | 52 | 53 | 54 | 55 | flying_path = filepath + [x for x in image if x == 'frames_cleanpass'][0] 56 | flying_disp = filepath + [x for x in disp if x == 'frames_disparity'][0] 57 | flying_dir = flying_path+'/TRAIN/' 58 | subdir = ['A','B','C'] 59 | 60 | for ss in subdir: 61 | flying = os.listdir(flying_dir+ss) 62 | 63 | for ff in flying: 64 | imm_l = os.listdir(flying_dir+ss+'/'+ff+'/left/') 65 | for im in imm_l: 66 | if is_image_file(flying_dir+ss+'/'+ff+'/left/'+im): 67 | all_left_img.append(flying_dir+ss+'/'+ff+'/left/'+im) 68 | 69 | all_left_disp.append(flying_disp+'/TRAIN/'+ss+'/'+ff+'/left/'+im.split(".")[0]+'.pfm') 70 | 71 | if is_image_file(flying_dir+ss+'/'+ff+'/right/'+im): 72 | all_right_img.append(flying_dir+ss+'/'+ff+'/right/'+im) 73 | 74 | flying_dir = flying_path+'/TEST/' 75 | 76 | subdir = ['A','B','C'] 77 | 78 | 79 | for ss in subdir: 80 | flying = os.listdir(flying_dir+ss) 81 | 82 | for ff in flying: 83 | imm_l = os.listdir(flying_dir+ss+'/'+ff+'/left/') 84 | for im in imm_l: 85 | if is_image_file(flying_dir+ss+'/'+ff+'/left/'+im): 86 | test_left_img.append(flying_dir+ss+'/'+ff+'/left/'+im) 87 | 88 | test_left_disp.append(flying_disp+'/TEST/'+ss+'/'+ff+'/left/'+im.split(".")[0]+'.pfm') 89 | 90 | if is_image_file(flying_dir+ss+'/'+ff+'/right/'+im): 91 | test_right_img.append(flying_dir+ss+'/'+ff+'/right/'+im) 92 | 93 | 94 | 95 | 96 | 97 | driving_dir = filepath + [x for x in image if 'driving' in x][0] + '/' 98 | driving_disp = filepath + [x for x in disp if 'driving' in x][0] 99 | 100 | subdir1 = ['35mm_focallength','15mm_focallength'] 101 | subdir2 = ['scene_backwards','scene_forwards'] 102 | subdir3 = ['fast','slow'] 103 | 104 | 105 | for i in subdir1: 106 | for j in subdir2: 107 | for k in subdir3: 108 | imm_l = os.listdir(driving_dir+i+'/'+j+'/'+k+'/left/') 109 | for im in imm_l: 110 | if is_image_file(driving_dir+i+'/'+j+'/'+k+'/left/'+im): 111 | all_left_img.append(driving_dir+i+'/'+j+'/'+k+'/left/'+im) 112 | all_left_disp.append(driving_disp+'/'+i+'/'+j+'/'+k+'/left/'+im.split(".")[0]+'.pfm') 113 | 114 | if is_image_file(driving_dir+i+'/'+j+'/'+k+'/right/'+im): 115 | all_right_img.append(driving_dir+i+'/'+j+'/'+k+'/right/'+im) 116 | 117 | 118 | print('len all_left_img:', len(all_left_img)) 119 | print('len test_left_img:', len(test_left_img)) 120 | 121 | return all_left_img, all_right_img, all_left_disp, test_left_img, test_right_img, test_left_disp 122 | 123 | 124 | -------------------------------------------------------------------------------- /utils/merge_img2video.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | from PIL import Image, ImageDraw, ImageFont 4 | import numpy as np 5 | import pdb 6 | 7 | def merge_img2video(image_path, video_path): 8 | # path = image_path # 图片序列所在目录,文件名:0.jpg 1.jpg ... 
9 | # dst_path = video_path #r'F:\dst\result.mp4' # 生成的视频路径 10 | 11 | filelist = os.listdir(image_path) 12 | filepref = [os.path.splitext(f)[0] for f in filelist] 13 | 14 | 15 | 16 | filepref.sort(key = int) # 按数字文件名排序 17 | #filepref= sorted(filepref,key=lambda x: int(x[:-6])) # 按数字文件名排序 18 | 19 | #pdb.set_trace() 20 | 21 | filelist = [f + '.png' for f in filepref] 22 | 23 | # size = (int(videoCapture.get(cv2.cv.CV_CAP_PROP_FRAME_WIDTH)), 24 | # int(videoCapture.get(cv2.cv.CV_CAP_PROP_FRAME_HEIGHT))) 25 | width = 1216 26 | height = 320 27 | 28 | # width = 1238 29 | # height = 374 30 | fps = 30 31 | 32 | vw = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'DIVX'), fps, (width , height)) 33 | 34 | #for file in filelist[5:-6]: 35 | for file in filelist: 36 | if file.endswith('.png'): 37 | file = os.path.join(image_path, file) 38 | print("file:", file) 39 | img = cv2.imread(file) 40 | print("img:", img.shape) 41 | 42 | 43 | # img = img[54:,22:,:] 44 | # #img = img[50:, 10:, :] 45 | 46 | 47 | print("img:", img.shape) 48 | #img = np.hstack((img, img)) # 如果并排两列显示 49 | vw.write(img) 50 | 51 | 52 | vw.release() 53 | 54 | 55 | 56 | def merge_video(): 57 | 58 | videoLeftUp = cv2.VideoCapture('/home/wsgan/LWANet/results/video/0028/raw_img_title.mp4') 59 | videoLeftDown = cv2.VideoCapture('/home/wsgan/LWANet/results/video/0028/GT_supervise/GT_supervise_subtitile.mp4') 60 | videoRightUp = cv2.VideoCapture('/home/wsgan/LWANet/results/video/0028/self_supervise/Self_supervise_subtitle.mp4') 61 | videoRightDown = cv2.VideoCapture('/home/wsgan/LWANet/results/video/0028/no_supervise/No_supervise_subtitle.mp4') 62 | 63 | fps = videoLeftUp.get(cv2.CAP_PROP_FPS) 64 | 65 | width = (int(videoLeftUp.get(cv2.CAP_PROP_FRAME_WIDTH))) 66 | height = (int(videoLeftUp.get(cv2.CAP_PROP_FRAME_HEIGHT))) 67 | 68 | videoWriter = cv2.VideoWriter('/home/wsgan/LWANet/results/video/0028/merge0028.mp4', cv2.VideoWriter_fourcc('m', 'p', '4', 'v'), fps, (width, height)) 69 | 70 | successLeftUp, frameLeftUp = videoLeftUp.read() 71 | successLeftDown, frameLeftDown = videoLeftDown.read() 72 | successRightUp, frameRightUp = videoRightUp.read() 73 | successRightDown, frameRightDown = videoRightDown.read() 74 | 75 | while successLeftUp and successLeftDown and successRightUp and successRightDown: 76 | frameLeftUp = cv2.resize(frameLeftUp, (int(width / 2), int(height / 2)), interpolation=cv2.INTER_CUBIC) 77 | frameLeftDown = cv2.resize(frameLeftDown, (int(width / 2), int(height / 2)), interpolation=cv2.INTER_CUBIC) 78 | frameRightUp = cv2.resize(frameRightUp, (int(width / 2), int(height / 2)), interpolation=cv2.INTER_CUBIC) 79 | frameRightDown = cv2.resize(frameRightDown, (int(width / 2), int(height / 2)), interpolation=cv2.INTER_CUBIC) 80 | 81 | frameUp = np.hstack((frameLeftUp, frameRightUp)) 82 | frameDown = np.hstack((frameLeftDown, frameRightDown)) 83 | frame = np.vstack((frameUp, frameDown)) 84 | 85 | videoWriter.write(frame) 86 | successLeftUp, frameLeftUp = videoLeftUp.read() 87 | successLeftDown, frameLeftDown = videoLeftDown.read() 88 | successRightUp, frameRightUp = videoRightUp.read() 89 | successRightDown, frameRightDown = videoRightDown.read() 90 | 91 | videoWriter.release() 92 | videoLeftUp.release() 93 | videoLeftDown.release() 94 | videoRightUp.release() 95 | videoRightDown.release() 96 | 97 | 98 | 99 | def add_subtitle(video_path, save_path): 100 | 101 | cap = cv2.VideoCapture(video_path) # 读取视频 102 | 103 | # Define the codec and create VideoWriter object 104 | fourcc = cv2.VideoWriter_fourcc(*'XVID') 105 | out = 
cv2.VideoWriter(save_path, fourcc, 30.0, (1216, 320)) # 输出视频参数设置 106 | 107 | while (cap.isOpened()): 108 | ret, frame = cap.read() 109 | if ret == True: 110 | # 在 frame 上显示一些信息 111 | img_PIL = Image.fromarray(frame[..., ::-1]) # 转成 array 112 | font = ImageFont.truetype('UbuntuMono-B.ttf', 40) # 字体设置,Windows系统可以在 "C:\Windows\Fonts" 下查找 113 | text1 = "Self_supervise" 114 | 115 | for i, te in enumerate(text1): 116 | # position = (50, 10 + i * 50) 117 | position = (10 + i * 20, 20 ) 118 | draw = ImageDraw.Draw(img_PIL) 119 | draw.text(position, te, font=font, fill=(255, 0, 0)) 120 | 121 | frame = cv2.cvtColor(np.asarray(img_PIL), cv2.COLOR_RGB2BGR) 122 | 123 | # write the frame 124 | #cv2.imshow('frame', frame) 125 | out.write(frame) 126 | # if cv2.waitKey(1) & 0xFF == ord('q'): 127 | # break 128 | else: 129 | break 130 | 131 | # Release everything if job is finished 132 | cap.release() 133 | out.release() 134 | #cv2.destroyAllWindows() 135 | 136 | 137 | 138 | 139 | 140 | 141 | if __name__ == "__main__": 142 | print('start') 143 | #merge_img2video('/home/wsgan/LWANet/results/video/0071/self_supervise/disparity', '/home/wsgan/LWANet/results/video/0071/self_supervise/Self_supervise.mp4' ) 144 | merge_video() 145 | #add_subtitle('/home/wsgan/LWANet/results/video/0071/self_supervise/Self_supervise.mp4', '/home/wsgan/LWANet/results/video/0071/self_supervise/Self_supervise_subtitle.mp4') 146 | print('end') -------------------------------------------------------------------------------- /utils/preprocess.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms as transforms 3 | import random 4 | 5 | __imagenet_stats = {'mean': [0.485, 0.456, 0.406], 6 | 'std': [0.229, 0.224, 0.225]} 7 | 8 | #__imagenet_stats = {'mean': [0.5, 0.5, 0.5], 9 | # 'std': [0.5, 0.5, 0.5]} 10 | 11 | __imagenet_pca = { 12 | 'eigval': torch.Tensor([0.2175, 0.0188, 0.0045]), 13 | 'eigvec': torch.Tensor([ 14 | [-0.5675, 0.7192, 0.4009], 15 | [-0.5808, -0.0045, -0.8140], 16 | [-0.5836, -0.6948, 0.4203], 17 | ]) 18 | } 19 | 20 | 21 | def scale_crop(input_size, scale_size=None, normalize=__imagenet_stats): 22 | t_list = [ 23 | transforms.ToTensor(), 24 | transforms.Normalize(**normalize), 25 | ] 26 | #if scale_size != input_size: 27 | #t_list = [transforms.Scale((960,540))] + t_list 28 | 29 | return transforms.Compose(t_list) 30 | 31 | 32 | def scale_random_crop(input_size, scale_size=None, normalize=__imagenet_stats): 33 | t_list = [ 34 | transforms.RandomCrop(input_size), 35 | transforms.ToTensor(), 36 | transforms.Normalize(**normalize), 37 | ] 38 | if scale_size != input_size: 39 | t_list = [transforms.Scale(scale_size)] + t_list 40 | 41 | transforms.Compose(t_list) 42 | 43 | 44 | def pad_random_crop(input_size, scale_size=None, normalize=__imagenet_stats): 45 | padding = int((scale_size - input_size) / 2) 46 | return transforms.Compose([ 47 | transforms.RandomCrop(input_size, padding=padding), 48 | transforms.RandomHorizontalFlip(), 49 | transforms.ToTensor(), 50 | transforms.Normalize(**normalize), 51 | ]) 52 | 53 | 54 | def inception_preproccess(input_size, normalize=__imagenet_stats): 55 | return transforms.Compose([ 56 | transforms.RandomSizedCrop(input_size), 57 | transforms.RandomHorizontalFlip(), 58 | transforms.ToTensor(), 59 | transforms.Normalize(**normalize) 60 | ]) 61 | def inception_color_preproccess(input_size, normalize=__imagenet_stats): 62 | return transforms.Compose([ 63 | #transforms.RandomSizedCrop(input_size), 64 | 
#transforms.RandomHorizontalFlip(), 65 | transforms.ToTensor(), 66 | ColorJitter( 67 | brightness=0.4, 68 | contrast=0.4, 69 | saturation=0.4, 70 | ), 71 | Lighting(0.1, __imagenet_pca['eigval'], __imagenet_pca['eigvec']), 72 | transforms.Normalize(**normalize) 73 | ]) 74 | 75 | 76 | def get_transform(name='imagenet', input_size=None, 77 | scale_size=None, normalize=None, augment=True): 78 | normalize = __imagenet_stats 79 | input_size = 256 80 | if augment: 81 | return inception_color_preproccess(input_size, normalize=normalize) 82 | else: 83 | return scale_crop(input_size=input_size, 84 | scale_size=scale_size, normalize=normalize) 85 | 86 | 87 | 88 | 89 | class Lighting(object): 90 | """Lighting noise(AlexNet - style PCA - based noise)""" 91 | 92 | def __init__(self, alphastd, eigval, eigvec): 93 | self.alphastd = alphastd 94 | self.eigval = eigval 95 | self.eigvec = eigvec 96 | 97 | def __call__(self, img): 98 | if self.alphastd == 0: 99 | return img 100 | 101 | alpha = img.new().resize_(3).normal_(0, self.alphastd) 102 | rgb = self.eigvec.type_as(img).clone()\ 103 | .mul(alpha.view(1, 3).expand(3, 3))\ 104 | .mul(self.eigval.view(1, 3).expand(3, 3))\ 105 | .sum(1).squeeze() 106 | 107 | return img.add(rgb.view(3, 1, 1).expand_as(img)) 108 | 109 | 110 | class Grayscale(object): 111 | 112 | def __call__(self, img): 113 | gs = img.clone() 114 | gs[0].mul_(0.299).add_(0.587, gs[1]).add_(0.114, gs[2]) 115 | gs[1].copy_(gs[0]) 116 | gs[2].copy_(gs[0]) 117 | return gs 118 | 119 | 120 | class Saturation(object): 121 | 122 | def __init__(self, var): 123 | self.var = var 124 | 125 | def __call__(self, img): 126 | gs = Grayscale()(img) 127 | alpha = random.uniform(0, self.var) 128 | return img.lerp(gs, alpha) 129 | 130 | 131 | class Brightness(object): 132 | 133 | def __init__(self, var): 134 | self.var = var 135 | 136 | def __call__(self, img): 137 | gs = img.new().resize_as_(img).zero_() 138 | alpha = random.uniform(0, self.var) 139 | return img.lerp(gs, alpha) 140 | 141 | 142 | class Contrast(object): 143 | 144 | def __init__(self, var): 145 | self.var = var 146 | 147 | def __call__(self, img): 148 | gs = Grayscale()(img) 149 | gs.fill_(gs.mean()) 150 | alpha = random.uniform(0, self.var) 151 | return img.lerp(gs, alpha) 152 | 153 | 154 | class RandomOrder(object): 155 | """ Composes several transforms together in random order. 
156 | """ 157 | 158 | def __init__(self, transforms): 159 | self.transforms = transforms 160 | 161 | def __call__(self, img): 162 | if self.transforms is None: 163 | return img 164 | order = torch.randperm(len(self.transforms)) 165 | for i in order: 166 | img = self.transforms[i](img) 167 | return img 168 | 169 | 170 | class ColorJitter(RandomOrder): 171 | 172 | def __init__(self, brightness=0.4, contrast=0.4, saturation=0.4): 173 | self.transforms = [] 174 | if brightness != 0: 175 | self.transforms.append(Brightness(brightness)) 176 | if contrast != 0: 177 | self.transforms.append(Contrast(contrast)) 178 | if saturation != 0: 179 | self.transforms.append(Saturation(saturation)) 180 | -------------------------------------------------------------------------------- /dataloader/preprocess.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms as transforms 3 | import random 4 | 5 | __imagenet_stats = {'mean': [0.485, 0.456, 0.406], 6 | 'std': [0.229, 0.224, 0.225]} 7 | 8 | #__imagenet_stats = {'mean': [0.5, 0.5, 0.5], 9 | # 'std': [0.5, 0.5, 0.5]} 10 | 11 | __imagenet_pca = { 12 | 'eigval': torch.Tensor([0.2175, 0.0188, 0.0045]), 13 | 'eigvec': torch.Tensor([ 14 | [-0.5675, 0.7192, 0.4009], 15 | [-0.5808, -0.0045, -0.8140], 16 | [-0.5836, -0.6948, 0.4203], 17 | ]) 18 | } 19 | 20 | 21 | def scale_crop(input_size, scale_size=None, normalize=__imagenet_stats): 22 | t_list = [ 23 | transforms.ToTensor(), 24 | transforms.Normalize(**normalize), 25 | ] 26 | #if scale_size != input_size: 27 | #t_list = [transforms.Scale((960,540))] + t_list 28 | 29 | return transforms.Compose(t_list) 30 | 31 | 32 | def scale_random_crop(input_size, scale_size=None, normalize=__imagenet_stats): 33 | t_list = [ 34 | transforms.RandomCrop(input_size), 35 | transforms.ToTensor(), 36 | transforms.Normalize(**normalize), 37 | ] 38 | if scale_size != input_size: 39 | t_list = [transforms.Scale(scale_size)] + t_list 40 | 41 | transforms.Compose(t_list) 42 | 43 | 44 | def pad_random_crop(input_size, scale_size=None, normalize=__imagenet_stats): 45 | padding = int((scale_size - input_size) / 2) 46 | return transforms.Compose([ 47 | transforms.RandomCrop(input_size, padding=padding), 48 | transforms.RandomHorizontalFlip(), 49 | transforms.ToTensor(), 50 | transforms.Normalize(**normalize), 51 | ]) 52 | 53 | 54 | def inception_preproccess(input_size, normalize=__imagenet_stats): 55 | return transforms.Compose([ 56 | transforms.RandomSizedCrop(input_size), 57 | transforms.RandomHorizontalFlip(), 58 | transforms.ToTensor(), 59 | transforms.Normalize(**normalize) 60 | ]) 61 | def inception_color_preproccess(input_size, normalize=__imagenet_stats): 62 | return transforms.Compose([ 63 | #transforms.RandomSizedCrop(input_size), 64 | #transforms.RandomHorizontalFlip(), 65 | transforms.ToTensor(), 66 | ColorJitter( 67 | brightness=0.4, 68 | contrast=0.4, 69 | saturation=0.4, 70 | ), 71 | Lighting(0.1, __imagenet_pca['eigval'], __imagenet_pca['eigvec']), 72 | transforms.Normalize(**normalize) 73 | ]) 74 | 75 | 76 | def get_transform(name='imagenet', input_size=None, 77 | scale_size=None, normalize=None, augment=True): 78 | normalize = __imagenet_stats 79 | input_size = 256 80 | if augment: 81 | return inception_color_preproccess(input_size, normalize=normalize) 82 | else: 83 | return scale_crop(input_size=input_size, 84 | scale_size=scale_size, normalize=normalize) 85 | 86 | 87 | 88 | 89 | class Lighting(object): 90 | """Lighting noise(AlexNet - style PCA 
- based noise)""" 91 | 92 | def __init__(self, alphastd, eigval, eigvec): 93 | self.alphastd = alphastd 94 | self.eigval = eigval 95 | self.eigvec = eigvec 96 | 97 | def __call__(self, img): 98 | if self.alphastd == 0: 99 | return img 100 | 101 | alpha = img.new().resize_(3).normal_(0, self.alphastd) 102 | rgb = self.eigvec.type_as(img).clone()\ 103 | .mul(alpha.view(1, 3).expand(3, 3))\ 104 | .mul(self.eigval.view(1, 3).expand(3, 3))\ 105 | .sum(1).squeeze() 106 | 107 | return img.add(rgb.view(3, 1, 1).expand_as(img)) 108 | 109 | 110 | class Grayscale(object): 111 | 112 | def __call__(self, img): 113 | gs = img.clone() 114 | gs[0].mul_(0.299).add_(0.587, gs[1]).add_(0.114, gs[2]) 115 | gs[1].copy_(gs[0]) 116 | gs[2].copy_(gs[0]) 117 | return gs 118 | 119 | 120 | class Saturation(object): 121 | 122 | def __init__(self, var): 123 | self.var = var 124 | 125 | def __call__(self, img): 126 | gs = Grayscale()(img) 127 | alpha = random.uniform(0, self.var) 128 | return img.lerp(gs, alpha) 129 | 130 | 131 | class Brightness(object): 132 | 133 | def __init__(self, var): 134 | self.var = var 135 | 136 | def __call__(self, img): 137 | gs = img.new().resize_as_(img).zero_() 138 | alpha = random.uniform(0, self.var) 139 | return img.lerp(gs, alpha) 140 | 141 | 142 | class Contrast(object): 143 | 144 | def __init__(self, var): 145 | self.var = var 146 | 147 | def __call__(self, img): 148 | gs = Grayscale()(img) 149 | gs.fill_(gs.mean()) 150 | alpha = random.uniform(0, self.var) 151 | return img.lerp(gs, alpha) 152 | 153 | 154 | class RandomOrder(object): 155 | """ Composes several transforms together in random order. 156 | """ 157 | 158 | def __init__(self, transforms): 159 | self.transforms = transforms 160 | 161 | def __call__(self, img): 162 | if self.transforms is None: 163 | return img 164 | order = torch.randperm(len(self.transforms)) 165 | for i in order: 166 | img = self.transforms[i](img) 167 | return img 168 | 169 | 170 | class ColorJitter(RandomOrder): 171 | 172 | def __init__(self, brightness=0.4, contrast=0.4, saturation=0.4): 173 | self.transforms = [] 174 | if brightness != 0: 175 | self.transforms.append(Brightness(brightness)) 176 | if contrast != 0: 177 | self.transforms.append(Contrast(contrast)) 178 | if saturation != 0: 179 | self.transforms.append(Saturation(saturation)) 180 | -------------------------------------------------------------------------------- /submission.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | import torchvision.transforms as transforms 6 | import argparse 7 | np.set_printoptions(threshold=np.inf) 8 | import torch.nn.functional as F 9 | from PIL import Image 10 | import utils.logger as logger 11 | import time 12 | from models.LWANet import * 13 | 14 | 15 | parser = argparse.ArgumentParser(description='LWANet submission') 16 | 17 | parser = argparse.ArgumentParser(description='AnyNet with Flyingthings3d') 18 | parser.add_argument('--maxdisp', type=int, default=192, help='maxium disparity') 19 | parser.add_argument('--loss_weights', type=float, nargs='+', default=[1., 1.]) 20 | parser.add_argument('--maxdisplist', type=int, nargs='+', default=[24, 3, 3]) 21 | parser.add_argument('--lr', type=float, default=5e-4, help='learning rate') 22 | parser.add_argument('--with_cspn', type =bool, default= True, help='with cspn network or not') 23 | parser.add_argument('--datapath', default='/data6/wsgan/SenceFlow/train/', help='datapath') 24 | 
parser.add_argument('--epochs', type=int, default=50, help='number of epochs to train')
25 | parser.add_argument('--train_bsize', type=int, default=8, help='batch size for training (default: 8)')
26 | parser.add_argument('--test_bsize', type=int, default=8, help='batch size for testing (default: 8)')
27 | parser.add_argument('--save_path', type=str, default='./results/kitti2015/benchmark', help='the path for saving checkpoints and logs')
28 | parser.add_argument('--resume', type=str, default=None, help='resume path')
29 | parser.add_argument('--print_freq', type=int, default=400, help='print frequency')
30 | parser.add_argument('--pretrained', type=str, default=None, help='path of the pretrained model to load')
31 | parser.add_argument('--model_types', type=str, default='original', help='model_types : LWANet_3D, mix, original')
32 | parser.add_argument('--conv_3d_types1', type=str, default='separate_only', help='conv_3d_types1 : 3D, P3D')
33 | parser.add_argument('--conv_3d_types2', type=str, default='separate_only', help='conv_3d_types2 : 3D, P3D')
34 | parser.add_argument('--cost_volume', type=str, default='Difference', help='cost_volume type : "Concat" , "Difference" or "Distance_based" ')
35 | parser.add_argument('--train', type=bool, default=True, help='train or test')
36 | parser.add_argument('--cuda', type=bool, default=True, help='use CUDA or not')
37 | 
38 | parser.add_argument('--datapath2015', default='/data6/wsgan/KITTI/KITTI2015/testing/', help='datapath')
39 | parser.add_argument('--datapath2012', default='/data6/wsgan/KITTI/KITTI2012/testing/', help='datapath')
40 | parser.add_argument('--datatype', default='2015', help='finetune dataset: 2012, 2015')
41 | 
42 | args = parser.parse_args()
43 | 
44 | 
45 | 
46 | if args.datatype == '2015':
47 |     from dataloader import KITTI_submission_loader as DA
48 | 
49 |     test_left_img, test_right_img = DA.dataloader2015(args.datapath2015)
50 | 
51 | elif args.datatype == '2012':
52 | 
53 |     from dataloader import KITTI_submission_loader as DA
54 |     test_left_img, test_right_img = DA.dataloader2012(args.datapath2012)
55 | 
56 | else:
57 | 
58 |     raise AssertionError("datatype must be 2012 or 2015")
59 | 
60 | 
61 | 
62 | 
63 | log = logger.setup_logger(args.save_path + '/training.log')
64 | for key, value in sorted(vars(args).items()):
65 |     log.info(str(key) + ': ' + str(value))
66 | 
67 | # build the model before loading any pretrained weights
68 | model = LWANet(args)
69 | if args.cuda:
70 |     model = nn.DataParallel(model)
71 |     model.cuda()
72 | 
73 | if args.pretrained:
74 |     try:
75 |         checkpoint = torch.load(args.pretrained)
76 |         model.load_state_dict(checkpoint['state_dict'], strict=False)
77 |         log.info('=> loaded pretrained model {}'.format(args.pretrained))
78 |     except FileNotFoundError:
79 |         log.info('=> no pretrained model found at {}'.format(args.pretrained))
80 |         log.info("=> Will start from scratch.")
81 | else:
82 |     log.info('Not Resume')
83 | 
84 | 
85 | 
86 | 
87 | 
88 | 
89 | def test(imgL, imgR):
90 | 
91 |     model.eval()
92 |     if args.cuda:
93 |         imgL = imgL.cuda()
94 |         imgR = imgR.cuda()
95 | 
96 |     with torch.no_grad():
97 |         disp, loss = model(imgL, imgR)
98 | 
99 |     disp = torch.squeeze(disp[-1])
100 |     # print('disp size:', disp.shape)
101 |     pred_disp = disp.data.cpu().numpy()
102 | 
103 |     return pred_disp
104 | 
105 | 
106 | 
107 | def main():
108 |     normal_mean_var = {'mean': [0.485, 0.456, 0.406],
109 |                        'std': [0.229, 0.224, 0.225]}
110 |     infer_transform = transforms.Compose([transforms.ToTensor(),
111 |                                           transforms.Normalize(**normal_mean_var)])
112 | 
113 |     total_inference_time = 0
114 | 
115 |     for inx in range(len(test_left_img)):
116 | 
117 |         imgL_o = Image.open(test_left_img[inx]).convert('RGB')
118 |         imgR_o = Image.open(test_right_img[inx]).convert('RGB')
119 | 
120 | 
121 |         imgL = infer_transform(imgL_o)
122 |         imgR = infer_transform(imgR_o)
123 | 
124 | 
125 |         # pad width and height up to multiples of 16
126 |         if imgL.shape[1] % 16 != 0:
127 |             times = imgL.shape[1] // 16
128 |             top_pad = (times + 1) * 16 - imgL.shape[1]
129 |         else:
130 |             top_pad = 0
131 | 
132 |         if imgL.shape[2] % 16 != 0:
133 |             times = imgL.shape[2] // 16
134 |             right_pad = (times + 1) * 16 - imgL.shape[2]
135 |         else:
136 |             right_pad = 0
137 | 
138 |         imgL = F.pad(imgL, (0, right_pad, top_pad, 0)).unsqueeze(0)
139 |         imgR = F.pad(imgR, (0, right_pad, top_pad, 0)).unsqueeze(0)
140 | 
141 |         start_time = time.time()
142 |         pred_disp = test(imgL, imgR)
143 | 
144 |         total_inference_time += time.time() - start_time
145 | 
146 |         if top_pad != 0 or right_pad != 0:
147 |             img = pred_disp[top_pad:, :-right_pad]
148 |         else:
149 |             img = pred_disp
150 | 
151 |         img = (img * 256).astype('uint16')
152 |         img = Image.fromarray(img)
153 |         print("inx:", inx)
154 |         img.save(args.save_path + '/' + test_left_img[inx].split('/')[-1])
155 | 
156 | 
157 |     log.info("mean inference time: %.3fs " % (total_inference_time / len(test_left_img)))
158 | 
159 |     log.info("finished inference on {} images".format(len(test_left_img)))
160 | 
161 | 
162 | 
163 | if __name__ == '__main__':
164 |     main()
165 | 
166 | 
--------------------------------------------------------------------------------
/models/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | class self_supervised_loss(nn.modules.Module):
6 |     def __init__(self, n=1, SSIM_w=0.85, disp_gradient_w=1.0, lr_w=1.0):
7 |         super(self_supervised_loss, self).__init__()
8 |         self.SSIM_w = SSIM_w
9 |         self.disp_gradient_w = disp_gradient_w
10 |         self.lr_w = lr_w
11 |         self.n = n
12 | 
13 |     def scale_pyramid(self, img):
14 |         scaled_imgs = [img]
15 | 
16 |         return scaled_imgs
17 | 
18 |     def gradient_x(self, img):
19 |         # Pad input to keep output size consistent
20 |         img = F.pad(img, (0, 1, 0, 0), mode="replicate")
21 |         gx = img[:, :, :, :-1] - img[:, :, :, 1:]  # NCHW
22 |         return gx
23 | 
24 |     def gradient_y(self, img):
25 |         # Pad input to keep output size consistent
26 |         img = F.pad(img, (0, 0, 0, 1), mode="replicate")
27 |         gy = img[:, :, :-1, :] - img[:, :, 1:, :]  # NCHW
28 |         return gy
29 | 
30 | 
31 | 
32 |     def apply_disparity(self, img, disp, cuda=True):
33 |         '''
34 |         img.shape = b, c, h, w
35 |         disp.shape = b, h, w
36 |         '''
37 |         b, c, h, w = img.shape
38 |         disp = disp.squeeze(1)
39 | 
40 |         if cuda:
41 |             right_coor_x = (torch.arange(start=0, end=w, out=torch.cuda.FloatTensor())).repeat(b, h, 1)
42 |             right_coor_y = (torch.arange(start=0, end=h, out=torch.cuda.FloatTensor())).repeat(b, w, 1).transpose(1, 2)
43 |         else:
44 |             right_coor_x = (torch.arange(start=0, end=w, out=torch.FloatTensor())).repeat(b, h, 1)
45 |             right_coor_y = (torch.arange(start=0, end=h, out=torch.FloatTensor())).repeat(b, w, 1).transpose(1, 2)
46 |         left_coor_x1 = right_coor_x + disp
47 |         left_coor_norm1 = torch.stack((left_coor_x1 / (w - 1) * 2 - 1, right_coor_y / (h - 1) * 2 - 1), dim=1)
48 |         ## backward warp
49 |         warp_img = torch.nn.functional.grid_sample(img, left_coor_norm1.permute(0, 2, 3, 1))
50 | 
51 |         return warp_img
52 | 
53 |     def generate_image_left(self, img, disp):
54 |         return self.apply_disparity(img, -disp)
55 | 
56 |     def generate_image_right(self, img, disp):
57 |         return self.apply_disparity(img, disp)
58 | 
59 |     def SSIM(self, x, y):
60 |         C1 = 0.01 ** 2
61 |         C2 = 0.03 ** 2
62 | 
63 |         mu_x = nn.AvgPool2d(3, 1)(x)
64 |         mu_y = nn.AvgPool2d(3, 1)(y)
65 |
mu_x_mu_y = mu_x * mu_y 66 | mu_x_sq = mu_x.pow(2) 67 | mu_y_sq = mu_y.pow(2) 68 | 69 | sigma_x = nn.AvgPool2d(3, 1)(x * x) - mu_x_sq 70 | sigma_y = nn.AvgPool2d(3, 1)(y * y) - mu_y_sq 71 | sigma_xy = nn.AvgPool2d(3, 1)(x * y) - mu_x_mu_y 72 | 73 | SSIM_n = (2 * mu_x_mu_y + C1) * (2 * sigma_xy + C2) 74 | SSIM_d = (mu_x_sq + mu_y_sq + C1) * (sigma_x + sigma_y + C2) 75 | SSIM = SSIM_n / SSIM_d 76 | 77 | return torch.clamp((1 - SSIM) / 2, 0, 1) 78 | 79 | def SSIM_WEIGHT(self, x, y): 80 | C1 = 0.01 ** 2 81 | C2 = 0.03 ** 2 82 | 83 | mu_x = nn.AvgPool2d(3, 1)(x) 84 | mu_y = nn.AvgPool2d(3, 1)(y) 85 | mu_x_mu_y = mu_x * mu_y 86 | mu_x_sq = mu_x.pow(2) 87 | mu_y_sq = mu_y.pow(2) 88 | 89 | sigma_x = nn.AvgPool2d(3, 1)(x * x) - mu_x_sq 90 | sigma_y = nn.AvgPool2d(3, 1)(y * y) - mu_y_sq 91 | sigma_xy = nn.AvgPool2d(3, 1)(x * y) - mu_x_mu_y 92 | 93 | SSIM_n = (2 * mu_x_mu_y + C1) * (2 * sigma_xy + C2) 94 | SSIM_d = (mu_x_sq + mu_y_sq + C1) * (sigma_x + sigma_y + C2) 95 | SSIM = SSIM_n / SSIM_d 96 | 97 | return torch.clamp((SSIM) / 2, 0, 1) 98 | 99 | def disp_smoothness(self, disp, pyramid): 100 | disp_gradients_x = [self.gradient_x(d) for d in disp] 101 | disp_gradients_y = [self.gradient_y(d) for d in disp] 102 | 103 | image_gradients_x = [self.gradient_x(img) for img in pyramid] 104 | image_gradients_y = [self.gradient_y(img) for img in pyramid] 105 | 106 | weights_x = [torch.exp(-torch.mean(torch.abs(g), 1, 107 | keepdim=True)) for g in image_gradients_x] 108 | weights_y = [torch.exp(-torch.mean(torch.abs(g), 1, 109 | keepdim=True)) for g in image_gradients_y] 110 | 111 | smoothness_x = [disp_gradients_x[i] * weights_x[i] 112 | for i in range(self.n)] 113 | smoothness_y = [disp_gradients_y[i] * weights_y[i] 114 | for i in range(self.n)] 115 | 116 | return [torch.abs(smoothness_x[i]) + torch.abs(smoothness_y[i]) 117 | for i in range(self.n)] 118 | 119 | 120 | def reconstruction_image_first_order_gradient(self, left_est, left_pyramid): 121 | 122 | RI_x = [self.gradient_x(d) for d in left_est] 123 | RI_y = [self.gradient_y(d) for d in left_est] 124 | 125 | OI_x = [self.gradient_x(d) for d in left_pyramid] 126 | OI_y = [self.gradient_y(d) for d in left_pyramid] 127 | 128 | fisrt_order_loss = [torch.mean(torch.abs(RI_x[i] - OI_x[i])) + torch.mean(torch.abs(RI_y[i] - OI_y[i])) 129 | for i in range(self.n)] 130 | 131 | return fisrt_order_loss 132 | 133 | def forward(self, input, target): 134 | """ 135 | Args: 136 | input [disp1, disp2] 137 | target [left, right] 138 | 139 | Return: 140 | (float): The loss 141 | """ 142 | 143 | left, right = target 144 | 145 | 146 | disp_left_est = [input[:, 0, :, :].unsqueeze(1) ] 147 | 148 | left_pyramid = self.scale_pyramid(left) 149 | right_pyramid = self.scale_pyramid(right) 150 | 151 | # Generate images 152 | left_est = [self.generate_image_left(right_pyramid[0], 153 | disp_left_est[0]) ] 154 | 155 | 156 | # Disparities smoothness 157 | disp_left_smoothness = self.disp_smoothness(disp_left_est, left_pyramid) 158 | l1_left = [torch.mean(torch.abs(left_est[0] - left_pyramid[0]))] 159 | ssim_left = [torch.mean(self.SSIM(left_est[0], left_pyramid[0])) ] 160 | image_loss_left = [self.SSIM_w * ssim_left[0] + (1 - self.SSIM_w) * (l1_left[0] )] 161 | 162 | image_loss = sum(image_loss_left) 163 | 164 | # Disparities smoothness 165 | disp_left_loss = [torch.mean(torch.abs(disp_left_smoothness[0])) ] 166 | disp_gradient_loss = sum(disp_left_loss) 167 | 168 | loss = image_loss + self.disp_gradient_w * disp_gradient_loss 169 | 170 | return loss 171 | 
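The self-supervised objective above follows the standard photometric recipe: the right image is warped into the left view with the predicted disparity, the reconstruction is scored with 0.85 * SSIM + 0.15 * L1, and an edge-aware disparity smoothness term is added. The sketch below is only an illustrative usage example, not part of the repository: the batch size, crop size, and random tensors are assumptions, and a CUDA device is required because apply_disparity builds its sampling grid with torch.cuda.FloatTensor.

import torch
from models.loss import self_supervised_loss

# same weights as LWANet.py uses: n=1, SSIM_w=0.85, disp_gradient_w=0.1, lr_w=1
criterion = self_supervised_loss(n=1, SSIM_w=0.85, disp_gradient_w=0.1, lr_w=1)

left = torch.rand(2, 3, 64, 128, device='cuda')                      # left rectified image, NCHW in [0, 1]
right = torch.rand(2, 3, 64, 128, device='cuda')                     # right rectified image
disp = torch.rand(2, 1, 64, 128, device='cuda', requires_grad=True)  # predicted left disparity (channel 0 is used)

# forward() warps the right image into the left view with disp, compares the
# reconstruction against the left image, and adds the weighted smoothness term
loss = criterion(disp, [left, right])
loss.backward()  # gradients reach disp through grid_sample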
-------------------------------------------------------------------------------- /models/Aggregation_submodules.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | from __future__ import print_function 3 | import torch.nn as nn 4 | import math 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | 9 | def activation_function(types = "ELU"): # ELU or Relu 10 | 11 | 12 | if types == "ELU": 13 | 14 | return nn.Sequential(nn.ELU(inplace=True)) 15 | 16 | elif types == "Mish": 17 | 18 | nn.Sequential(Mish()) 19 | 20 | elif types == "Relu": 21 | 22 | return nn.Sequential(nn.ReLU(inplace=True)) 23 | 24 | else: 25 | 26 | AssertionError("please define the activate function types") 27 | 28 | 29 | 30 | 31 | class Mish(nn.Module): 32 | ''' 33 | Applies the mish function element-wise: 34 | mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x))) 35 | Shape: 36 | - Input: (N, *) where * means, any number of additional 37 | dimensions 38 | - Output: (N, *), same shape as the input 39 | Examples: 40 | 41 | ''' 42 | def __init__(self): 43 | ''' 44 | Init method. 45 | ''' 46 | super().__init__() 47 | 48 | def forward(self, input): 49 | ''' 50 | Forward pass of the function. 51 | ''' 52 | return input * torch.tanh(F.softplus(input)) 53 | 54 | 55 | 56 | # cost aggregation submodule 57 | 58 | def conv_3d(in_planes, out_planes, kernel_size, stride, pad, conv_3d_types="3D"): 59 | 60 | 61 | if conv_3d_types == "3D": 62 | 63 | return nn.Sequential( 64 | nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, padding=pad, stride=stride, bias=False) 65 | ) 66 | 67 | 68 | elif conv_3d_types == "P3D": # 3*3*3 to 1*3*3 + 3*1*1 69 | 70 | return nn.Sequential( 71 | 72 | nn.Conv3d(in_planes, out_planes, kernel_size=(1, 3, 3), stride=stride, padding=(0, 1, 1), bias=False), 73 | nn.ReLU(inplace=True), 74 | nn.Conv3d(out_planes, out_planes, kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0), bias=False), 75 | 76 | ) 77 | 78 | 79 | else: 80 | 81 | AssertionError("please define conv_3d_types") 82 | 83 | 84 | 85 | 86 | def convbn_3d(in_planes, out_planes, kernel_size, stride, pad, conv_3d_types="3D"): 87 | 88 | 89 | if conv_3d_types == "3D": 90 | 91 | return nn.Sequential( 92 | nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, padding=pad, stride=stride, bias=False), 93 | nn.BatchNorm3d(out_planes)) 94 | 95 | 96 | elif conv_3d_types == "P3D": # 3*3*3 to 1*3*3 + 3*1*1 97 | 98 | return nn.Sequential( 99 | 100 | nn.Conv3d(in_planes, out_planes, kernel_size=(1, 3, 3), stride=stride, padding=(0, 1, 1), bias=False), 101 | nn.ReLU(inplace=True), 102 | nn.Conv3d(out_planes, out_planes, kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0), bias=False), 103 | 104 | nn.BatchNorm3d(out_planes)) 105 | 106 | 107 | else: 108 | 109 | AssertionError("please define conv_3d_types") 110 | 111 | 112 | 113 | def convTranspose3d(in_planes, out_planes, kernel_size, stride, padding=1, conv_3d_types="P3D"): 114 | 115 | if conv_3d_types == '3D': 116 | return nn.Sequential( 117 | nn.ConvTranspose3d(in_planes, out_planes, kernel_size, padding = padding, output_padding=1, stride=stride, bias=False), 118 | nn.BatchNorm3d(out_planes)) 119 | 120 | 121 | elif conv_3d_types == "P3D": 122 | 123 | return nn.Sequential( 124 | nn.ConvTranspose3d(in_planes, out_planes, kernel_size, padding=padding, output_padding=1, stride=stride, 125 | bias=False), 126 | nn.BatchNorm3d(out_planes)) 127 | 128 | 129 | else: 130 | AssertionError("please define conv_3d_types") 131 | 132 | 133 | 134 | 135 | class 
LWANet_Aggregation(nn.Module): # base on PSMNet basic 136 | def __init__(self, input_planes=8, planes=16, maxdisp=192, conv_3d_types1 = "P3D", conv_3d_types2 = "P3D", activation_types2 = "ELU"): 137 | super(LWANet_Aggregation, self).__init__() 138 | self.maxdisp = maxdisp 139 | 140 | self.pre_3D = nn.Sequential( 141 | convbn_3d(input_planes, planes, 3, 1, 1, conv_3d_types = conv_3d_types1), 142 | activation_function(types = activation_types2), 143 | convbn_3d(planes, planes, 3, 2, 1, conv_3d_types = conv_3d_types1), 144 | activation_function(types = activation_types2) 145 | ) 146 | 147 | self.middle_3D = nn.Sequential( 148 | 149 | convbn_3d(planes, planes*2, 3, 1, 1, conv_3d_types = conv_3d_types2), 150 | activation_function(types = activation_types2), 151 | convbn_3d(planes*2, planes*4, 3, 1, 1, conv_3d_types = conv_3d_types2), 152 | activation_function(types = activation_types2), 153 | convbn_3d(planes * 4, planes * 4, 3, 1, 1, conv_3d_types=conv_3d_types2), 154 | activation_function(types=activation_types2), 155 | convbn_3d(planes * 4, planes * 2, 3, 1, 1, conv_3d_types=conv_3d_types2), 156 | activation_function(types=activation_types2), 157 | convTranspose3d(planes * 2, planes * 2, kernel_size=3, stride=2, conv_3d_types=conv_3d_types2), 158 | activation_function(types=activation_types2) 159 | ) 160 | 161 | self.post_3D = nn.Sequential( 162 | convbn_3d(planes*2, planes, 3, 1, 1, conv_3d_types = conv_3d_types1), 163 | activation_function(types = activation_types2), 164 | conv_3d(planes, 1, kernel_size=3, pad=1, stride=1, conv_3d_types = conv_3d_types1) 165 | ) 166 | 167 | 168 | for m in self.modules(): 169 | if isinstance(m, nn.Conv2d): 170 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 171 | m.weight.data.normal_(0, math.sqrt(2. / n)) 172 | elif isinstance(m, nn.Conv3d): 173 | n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels 174 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 175 | elif isinstance(m, nn.BatchNorm2d): 176 | m.weight.data.fill_(1) 177 | m.bias.data.zero_() 178 | elif isinstance(m, nn.BatchNorm3d): 179 | m.weight.data.fill_(1) 180 | m.bias.data.zero_() 181 | elif isinstance(m, nn.Linear): 182 | m.bias.data.zero_() 183 | 184 | def forward(self, cost): 185 | 186 | cost = self.pre_3D(cost) 187 | cost = self.middle_3D(cost) 188 | cost = self.post_3D(cost) 189 | 190 | 191 | return cost 192 | 193 | 194 | 195 | 196 | -------------------------------------------------------------------------------- /models/LWANet.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | from __future__ import print_function 3 | import torch 4 | import torch.nn as nn 5 | import torch.utils.data 6 | from torch.autograd import Variable 7 | import torch.nn.functional as F 8 | import math 9 | from .cspn import Affinity_Propagate 10 | from .feature_extraction import F1, F2, F3, F2_UP, F3_UP , F1_UP 11 | from .Aggregation_submodules import LWANet_Aggregation 12 | from .cost import _build_cost_volume 13 | from .loss import self_supervised_loss 14 | 15 | 16 | class LWANet(nn.Module): 17 | def __init__(self, args): 18 | super(LWANet, self).__init__() 19 | 20 | #self.init_channels = args.init_channels 21 | self.maxdisplist = args.maxdisplist 22 | self.with_cspn = args.with_cspn 23 | self.model_types =args.model_types # "LWANet: 3D orP3D 24 | self.conv_3d_types1 = args.conv_3d_types1 25 | self.conv_3d_types2 = args.conv_3d_types2 26 | self.cost_volume = args.cost_volume 27 | self.maxdisp = args.maxdisp 28 | 29 | 30 | self.F1 = F1() 31 | self.F2 = F2() 32 | self.F3 = F3() 33 | 34 | self.F1_CSPN = F1() 35 | self.F2_CSPN = F2() 36 | self.F3_CSPN = F3() 37 | 38 | if self.cost_volume =="Distance_based": 39 | 40 | self.volume_postprocess = LWANet_Aggregation( input_planes=1, planes=8, 41 | conv_3d_types1 = self.conv_3d_types1, 42 | conv_3d_types2 = self.conv_3d_types2) 43 | 44 | elif self.cost_volume =="Difference": 45 | self.volume_postprocess = LWANet_Aggregation(input_planes=16, planes=12, 46 | conv_3d_types1=self.conv_3d_types1, 47 | conv_3d_types2=self.conv_3d_types2) 48 | 49 | elif self.cost_volume =="Concat": 50 | self.volume_postprocess = LWANet_Aggregation(input_planes=32, planes=12, 51 | conv_3d_types1=self.conv_3d_types1, 52 | conv_3d_types2=self.conv_3d_types2) 53 | 54 | if self.with_cspn: 55 | 56 | self.F2_UP = F2_UP() 57 | self.F3_UP = F3_UP() 58 | self.F1_UP = F1_UP() 59 | 60 | cspn_config_default = {'step':4, 'kernel': 3, 'norm_type': '8sum'} 61 | self.post_process_layer = [self._make_post_process_layer(cspn_config_default)] 62 | self.post_process_layer = nn.ModuleList(self.post_process_layer) 63 | 64 | self.self_supervised_loss = self_supervised_loss(n=1, SSIM_w=0.85, disp_gradient_w=0.1, lr_w=1) 65 | 66 | for m in self.modules(): 67 | if isinstance(m, nn.Conv2d): 68 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 69 | m.weight.data.normal_(0, math.sqrt(2. / n)) 70 | elif isinstance(m, nn.Conv3d): 71 | n = m.kernel_size[0] * m.kernel_size[1]*m.kernel_size[2] * m.out_channels 72 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 73 | elif isinstance(m, nn.BatchNorm2d): 74 | m.weight.data.fill_(1) 75 | m.bias.data.zero_() 76 | elif isinstance(m, nn.BatchNorm3d): 77 | m.weight.data.fill_(1) 78 | m.bias.data.zero_() 79 | elif isinstance(m, nn.Linear): 80 | m.bias.data.zero_() 81 | 82 | 83 | def _make_post_process_layer(self, cspn_config=None): 84 | return Affinity_Propagate(cspn_config['step'], 85 | cspn_config['kernel'], 86 | norm_type=cspn_config['norm_type']) 87 | 88 | def forward(self, left, right): 89 | 90 | 91 | img_size = left.size() 92 | 93 | feats_l_F1 = self.F1(left) 94 | 95 | feats_l_F2 = self.F2(feats_l_F1) 96 | 97 | feats_l_F3 = self.F3(feats_l_F2) 98 | 99 | 100 | feats_l = feats_l_F3 101 | 102 | 103 | feats_r_F1 = self.F1(right) 104 | 105 | feats_r_F2 = self.F2(feats_r_F1) 106 | 107 | feats_r_F3 = self.F3(feats_r_F2) 108 | 109 | feats_r = feats_r_F3 110 | 111 | 112 | pred = [] 113 | 114 | cost = _build_cost_volume(self.cost_volume, feats_l, feats_r, self.maxdisp) 115 | 116 | cost = self.volume_postprocess(cost).squeeze(1) 117 | 118 | 119 | pred_low_res_left = disparityregression2(0, self.maxdisplist[0])(F.softmax(-cost, dim=1)) 120 | 121 | pred_low_res = pred_low_res_left * img_size[2] / pred_low_res_left.size(2) 122 | 123 | disp_up = F.upsample(pred_low_res, (img_size[2], img_size[3]), mode='bilinear') 124 | 125 | pred.append(disp_up) 126 | 127 | 128 | if self.with_cspn: 129 | 130 | feats_l_F1_CSPN = self.F1_CSPN(left) 131 | 132 | feats_l_F2_CSPN = self.F2_CSPN(feats_l_F1_CSPN) 133 | 134 | feats_l_F3_CSPN = self.F3_CSPN(feats_l_F2_CSPN) 135 | 136 | 137 | 138 | F3_UP = torch.cat((self.F3_UP(feats_l_F3_CSPN), feats_l_F2_CSPN), 1) 139 | 140 | F2_UP = torch.cat((self.F2_UP(F3_UP), feats_l_F1_CSPN), 1) 141 | 142 | F1_UP = self.F1_UP(F2_UP) 143 | 144 | x = self.post_process_layer[0](F1_UP, disp_up) 145 | 146 | pred.append(x) 147 | 148 | loss = [] 149 | 150 | 151 | if self.train: 152 | for outputs in pred: 153 | loss.append(self.self_supervised_loss(outputs, [left, right])) 154 | 155 | else: 156 | loss = [0] 157 | 158 | pred = [torch.squeeze(pred, 1) for pred in pred] 159 | 160 | return pred, loss 161 | 162 | 163 | 164 | class disparityregression2(nn.Module): 165 | def __init__(self, start, end, stride=1): 166 | super(disparityregression2, self).__init__() 167 | self.disp = Variable(torch.arange(start*stride, end*stride, stride).view(1, -1, 1, 1).cuda(), requires_grad=False) 168 | 169 | def forward(self, x): 170 | disp = self.disp.repeat(x.size()[0], 1, x.size()[2], x.size()[3]) 171 | disp = disp.float() 172 | 173 | out = torch.sum(x*disp, 1, keepdim=True) 174 | return out 175 | 176 | 177 | 178 | class L1Loss(object): 179 | def __call__(self, input, target): 180 | return torch.abs(input - target).mean() 181 | 182 | 183 | 184 | def apply_disparity(img, disp): 185 | 186 | batch_size, _, height, width = img.size() 187 | 188 | # Original coordinates of pixels 189 | x_base = torch.linspace(0, 1, width).repeat(batch_size, 190 | height, 1).type_as(img) 191 | y_base = torch.linspace(0, 1, height).repeat(batch_size, 192 | width, 1).transpose(1, 2).type_as(img) 193 | 194 | # Apply shift in X direction 195 | x_shifts = disp[:, 0, :, :] # Disparity is passed in NCHW format with 1 channel 196 | flow_field = torch.stack((x_base + x_shifts, y_base), dim=3) 197 | # In grid_sample coordinates are assumed to be between -1 and 1 198 | output = F.grid_sample(img, 2*flow_field - 1, mode='bilinear', 199 | padding_mode='zeros') 200 | 201 | return output 202 | 203 | 204 | 205 | def generate_image_left( img, disp): 206 | return 
apply_disparity(img, -disp) 207 | 208 | 209 | -------------------------------------------------------------------------------- /utils/flops_hook.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.modules.conv import _ConvNd 4 | multiply_adds = 1 5 | 6 | 7 | def count_convNd(m: _ConvNd, x: (torch.Tensor,), y: torch.Tensor): 8 | x = x[0] 9 | 10 | kernel_ops = m.weight.size()[2:].numel() # Kw x Kh 11 | 12 | # N x Cout x H x W x (Cin x Kw x Kh + bias) 13 | total_ops = y.nelement() * ( 14 | m.in_channels // m.groups * kernel_ops) 15 | 16 | m.total_ops += torch.Tensor([int(total_ops)]) 17 | 18 | 19 | 20 | 21 | def count_bn(m, x, y): 22 | x = x[0] 23 | 24 | nelements = x.numel() 25 | if not m.training: 26 | # subtract, divide, gamma, beta 27 | total_ops = 2 * nelements 28 | 29 | m.total_ops += torch.Tensor([int(total_ops)]) 30 | 31 | 32 | 33 | 34 | def count_relu(m, x, y): 35 | x = x[0] 36 | 37 | nelements = x.numel() 38 | 39 | m.total_ops += torch.Tensor([int(nelements)]) 40 | 41 | 42 | def count_softmax(m, x, y): 43 | x = x[0] 44 | 45 | batch_size, nfeatures = x.size() 46 | 47 | total_exp = nfeatures 48 | total_add = nfeatures - 1 49 | total_div = nfeatures 50 | total_ops = batch_size * (total_exp + total_add + total_div) 51 | 52 | m.total_ops += torch.Tensor([int(total_ops)]) 53 | 54 | 55 | def count_avgpool(m, x, y): 56 | # total_add = torch.prod(torch.Tensor([m.kernel_size])) 57 | # total_div = 1 58 | # kernel_ops = total_add + total_div 59 | kernel_ops = 1 60 | num_elements = y.numel() 61 | total_ops = kernel_ops * num_elements 62 | 63 | m.total_ops += torch.Tensor([int(total_ops)]) 64 | 65 | 66 | def count_adap_avgpool(m, x, y): 67 | kernel = torch.Tensor( 68 | [*(x[0].shape[2:])]) // torch.Tensor(list((m.output_size,))).squeeze() 69 | total_add = torch.prod(kernel) 70 | total_div = 1 71 | kernel_ops = total_add + total_div 72 | num_elements = y.numel() 73 | total_ops = kernel_ops * num_elements 74 | 75 | m.total_ops += torch.Tensor([int(total_ops)]) 76 | 77 | 78 | # TODO: verify the accuracy 79 | def count_upsample(m, x, y): 80 | if m.mode not in ("nearest", "linear", "bilinear", "bicubic",): # "trilinear" 81 | logger.warning( 82 | "mode %s is not implemented yet, take it a zero op" % m.mode) 83 | return zero_ops(m, x, y) 84 | 85 | if m.mode == "nearest": 86 | return zero_ops(m, x, y) 87 | 88 | x = x[0] 89 | if m.mode == "linear": 90 | total_ops = y.nelement() * 5 # 2 muls + 3 add 91 | elif m.mode == "bilinear": 92 | # https://en.wikipedia.org/wiki/Bilinear_interpolation 93 | total_ops = y.nelement() * 11 # 6 muls + 5 adds 94 | elif m.mode == "bicubic": 95 | # https://en.wikipedia.org/wiki/Bicubic_interpolation 96 | # Product matrix [4x4] x [4x4] x [4x4] 97 | ops_solve_A = 224 # 128 muls + 96 adds 98 | ops_solve_p = 35 # 16 muls + 12 adds + 4 muls + 3 adds 99 | total_ops = y.nelement() * (ops_solve_A + ops_solve_p) 100 | elif m.mode == "trilinear": 101 | # https://en.wikipedia.org/wiki/Trilinear_interpolation 102 | # can viewed as 2 bilinear + 1 linear 103 | total_ops = y.nelement() * (13 * 2 + 5) 104 | 105 | m.total_ops += torch.Tensor([int(total_ops)]) 106 | 107 | 108 | def count_linear(m, x, y): 109 | # per output element 110 | total_mul = m.in_features 111 | total_add = m.in_features - 1 112 | total_add += 1 if m.bias is not None else 0 113 | num_elements = y.numel() 114 | total_ops = (total_mul + total_add) * num_elements 115 | 116 | m.total_ops += torch.Tensor([int(total_ops)]) 117 | 118 | 
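# Worked example for count_linear (hypothetical layer, not taken from this repository):
# for nn.Linear(in_features=512, out_features=10) applied to a single sample,
# total_mul = 512, total_add = 511 + 1 (bias), and y.numel() = 10 output elements,
# so total_ops = (512 + 512) * 10 = 10240 for that layer.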
119 | def zero_ops(m, x, y): 120 | m.total_ops += torch.Tensor([int(0)]) 121 | 122 | 123 | 124 | register_hooks = { 125 | nn.Conv1d: count_convNd, 126 | nn.Conv2d: count_convNd, 127 | nn.Conv3d: count_convNd, 128 | nn.ConvTranspose1d: count_convNd, 129 | nn.ConvTranspose2d: count_convNd, 130 | nn.ConvTranspose3d: count_convNd, 131 | 132 | nn.BatchNorm1d: count_bn, 133 | nn.BatchNorm2d: count_bn, 134 | nn.BatchNorm3d: count_bn, 135 | 136 | nn.ReLU: zero_ops, 137 | nn.ReLU6: zero_ops, 138 | nn.LeakyReLU: count_relu, 139 | 140 | nn.MaxPool1d: zero_ops, 141 | nn.MaxPool2d: zero_ops, 142 | nn.MaxPool3d: zero_ops, 143 | nn.AdaptiveMaxPool1d: zero_ops, 144 | nn.AdaptiveMaxPool2d: zero_ops, 145 | nn.AdaptiveMaxPool3d: zero_ops, 146 | 147 | nn.AvgPool1d: count_avgpool, 148 | nn.AvgPool2d: count_avgpool, 149 | nn.AvgPool3d: count_avgpool, 150 | nn.AdaptiveAvgPool1d: count_adap_avgpool, 151 | nn.AdaptiveAvgPool2d: count_adap_avgpool, 152 | nn.AdaptiveAvgPool3d: count_adap_avgpool, 153 | 154 | nn.Linear: count_linear, 155 | nn.Dropout: zero_ops, 156 | 157 | nn.Upsample: count_upsample, 158 | nn.UpsamplingBilinear2d: count_upsample, 159 | nn.UpsamplingNearest2d: count_upsample 160 | } 161 | 162 | 163 | 164 | 165 | 166 | def profile(model, inputs, custom_ops=None, verbose=True): 167 | handler_collection = [] 168 | if custom_ops is None: 169 | custom_ops = {} 170 | 171 | def add_hooks(m): 172 | if len(list(m.children())) > 0: 173 | return 174 | 175 | # if hasattr(m, "total_ops") or hasattr(m, "total_params"): 176 | # raise Warning("Either .total_ops or .total_params is already defined in %s.\n" 177 | # "Be careful, it might change your code's behavior." % str(m)) 178 | 179 | m.register_buffer('total_ops', torch.zeros(1)) 180 | m.register_buffer('total_params', torch.zeros(1)) 181 | 182 | for p in m.parameters(): 183 | m.total_params += torch.Tensor([p.numel()]) 184 | 185 | m_type = type(m) 186 | fn = None 187 | if m_type in custom_ops: # if defined both op maps, use custom_ops to overwrite. 
188 | fn = custom_ops[m_type] 189 | elif m_type in register_hooks: 190 | fn = register_hooks[m_type] 191 | 192 | if fn is None: 193 | if verbose: 194 | print("THOP has not implemented counting method for", m) 195 | else: 196 | if verbose: 197 | print("Register FLOP counter for module %s" % str(m)) 198 | handler = m.register_forward_hook(fn) 199 | handler_collection.append(handler) 200 | 201 | # original_device = model.parameters().__next__().device 202 | training = model.training 203 | 204 | model.eval() 205 | model.apply(add_hooks) 206 | 207 | with torch.no_grad(): 208 | model(*inputs) 209 | 210 | total_ops = 0 211 | total_params = 0 212 | _temp = [] 213 | for m in model.modules(): 214 | if len(list(m.children())) > 0: # skip for non-leaf module 215 | continue 216 | total_ops += m.total_ops 217 | total_params += m.total_params 218 | _temp.append(m.total_ops.item()) 219 | 220 | total_ops = total_ops.item() 221 | total_params = total_params.item() 222 | 223 | # reset model to original status 224 | model.train(training) 225 | for handler in handler_collection: 226 | handler.remove() 227 | 228 | return clever_format(total_ops), clever_format(total_params) 229 | 230 | 231 | def clever_format(num): 232 | # if num > 1e12: 233 | # return "%.2f" % (num / 1e12) + "T" 234 | if num > 1e9: 235 | return "%.2f" % (num / 1e9) + "G" 236 | if num > 1e6: 237 | return "%.2f" % (num / 1e6) + "M" 238 | if num > 1e3: 239 | return "%.2f" % (num / 1e3) + "K" 240 | -------------------------------------------------------------------------------- /models/cost.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | 5 | 6 | def _build_volume_2d_anynet(feat_l, feat_r, maxdisp, stride=1): 7 | 8 | assert maxdisp % stride == 0 # Assume maxdisp is multiple of stride 9 | cost = torch.zeros((feat_l.size()[0], maxdisp//stride, feat_l.size()[2], feat_l.size()[3]), device='cuda') 10 | for i in range(0, maxdisp, stride): 11 | cost[:, i// stride, :, :i] = feat_l[:, :, :, :i].abs().sum(1) 12 | 13 | if i > 0: 14 | cost[:, i // stride, :, i:] = torch.norm(feat_l[:, :, :, i:] - feat_r[:, :, :, :-i], 1, 1) 15 | else: 16 | cost[:, i // stride, :, i:] = torch.norm(feat_l[:, :, :, :] - feat_r[:, :, :, :], 1, 1) 17 | 18 | return cost.contiguous() 19 | 20 | 21 | def _build_volume_2d3_anynet( feat_l, feat_r, disp, maxdisp=3, stride=1): 22 | size = feat_l.size() 23 | batch_disp = disp[:, None, :, :, :].repeat(1, maxdisp * 2 - 1, 1, 1, 1).view(-1, 1, size[-2], size[-1]) 24 | batch_shift = torch.arange(-maxdisp + 1, maxdisp, device='cuda').repeat(size[0])[:, None, None, None] * stride 25 | batch_disp = batch_disp - batch_shift.float() 26 | batch_feat_l = feat_l[:, None, :, :, :].repeat(1, maxdisp * 2 - 1, 1, 1, 1).view(-1, size[-3], size[-2], size[-1]) 27 | batch_feat_r = feat_r[:, None, :, :, :].repeat(1, maxdisp * 2 - 1, 1, 1, 1).view(-1, size[-3], size[-2], size[-1]) 28 | 29 | cost = torch.norm(batch_feat_l - warp(batch_feat_r, batch_disp), 1, 1) 30 | 31 | cost = cost.view(size[0], -1, size[2], size[3]) 32 | 33 | return cost.contiguous() 34 | 35 | 36 | def _build_volume_2d_psmnet( refimg_fea, targetimg_fea, maxdisp): 37 | cost = Variable( 38 | torch.FloatTensor(refimg_fea.size()[0], refimg_fea.size()[1] * 2, maxdisp , refimg_fea.size()[2], 39 | refimg_fea.size()[3]).zero_()).cuda() 40 | 41 | for i in range(maxdisp): 42 | if i > 0: 43 | 44 | cost[:, :refimg_fea.size()[1], i, :, i:] = refimg_fea[:, :, :, i:] 45 | cost[:, 
refimg_fea.size()[1]:, i, :, i:] = targetimg_fea[:, :, :, :-i] 46 | else: 47 | cost[:, :refimg_fea.size()[1], i, :, :] = refimg_fea 48 | cost[:, refimg_fea.size()[1]:, i, :, :] = targetimg_fea 49 | 50 | return cost.contiguous() 51 | 52 | 53 | 54 | 55 | def _build_volume_2d3_psmnet(feat_l, feat_r, disp, maxdisp=3, stride=1): 56 | size = feat_l.size() 57 | 58 | batch_disp = disp[:, None, :, :, :].repeat(1, maxdisp * 2 - 1, 1, 1, 1).view(-1, 1, size[-2], size[-1]) 59 | batch_shift = torch.arange(-maxdisp + 1, maxdisp, device='cuda').repeat(size[0])[:, None, None, None] * stride 60 | batch_disp = batch_disp - batch_shift.float() 61 | batch_feat_l = feat_l[:, None, :, :, :].repeat(1, maxdisp * 2 - 1, 1, 1, 1).view(-1, size[-3], size[-2], size[-1]) 62 | batch_feat_r = feat_r[:, None, :, :, :].repeat(1, maxdisp * 2 - 1, 1, 1, 1).view(-1, size[-3], size[-2], 63 | size[-1]) 64 | #cost = batch_feat_l - warp(batch_feat_r, batch_disp) 65 | cost = torch.cat((batch_feat_l , warp(batch_feat_r, batch_disp)), 1).contiguous() 66 | 67 | cost = cost.view(size[0], size[1]*2, -1, size[2], size[3]) 68 | # print("cost size", cost.shape) 69 | 70 | return cost.contiguous() 71 | 72 | 73 | 74 | 75 | def _build_volume_2d_aanet(refimg_fea, targetimg_fea, maxdisp): 76 | 77 | 78 | b, c, h, w = refimg_fea.size() 79 | cost_volume = refimg_fea.new_zeros(b, maxdisp, h, w) 80 | 81 | for i in range(maxdisp): 82 | if i > 0: 83 | cost_volume[:, i, :, i:] = (refimg_fea[:, :, :, i:] * 84 | targetimg_fea[:, :, :, :-i]).mean(dim=1) 85 | else: 86 | cost_volume[:, i, :, :] = (refimg_fea * targetimg_fea).mean(dim=1) 87 | 88 | return cost_volume.contiguous() 89 | 90 | 91 | 92 | 93 | 94 | def _build_volume_2d_difference(feat_l, feat_r, maxdisp): 95 | 96 | b, c, h, w = feat_l.size() 97 | 98 | 99 | cost_volume = feat_l.new_zeros(b, c, maxdisp, h, w) 100 | 101 | for i in range(maxdisp): 102 | if i > 0: 103 | cost_volume[:, :, i, :, i:] = feat_l[:, :, :, i:] - feat_r[:, :, :, :-i] 104 | else: 105 | cost_volume[:, :, i, :, :] = feat_l - feat_r 106 | 107 | return cost_volume 108 | 109 | 110 | 111 | 112 | 113 | def _build_volume_2d3_difference( feat_l, feat_r, disp, maxdisp=3, stride=1): 114 | size = feat_l.size() 115 | batch_disp = disp[:, None, :, :, :].repeat(1, maxdisp * 2 - 1, 1, 1, 1).view(-1, 1, size[-2], size[-1]) 116 | batch_shift = torch.arange(-maxdisp + 1, maxdisp, device='cuda').repeat(size[0])[:, None, None, None] * stride 117 | batch_disp = batch_disp - batch_shift.float() 118 | batch_feat_l = feat_l[:, None, :, :, :].repeat(1, maxdisp * 2 - 1, 1, 1, 1).view(-1, size[-3], size[-2], size[-1]) 119 | batch_feat_r = feat_r[:, None, :, :, :].repeat(1, maxdisp * 2 - 1, 1, 1, 1).view(-1, size[-3], size[-2], size[-1]) 120 | 121 | # cost = torch.norm(batch_feat_l - warp(batch_feat_r, batch_disp), 1, 1) 122 | 123 | cost = batch_feat_l - warp(batch_feat_r, batch_disp) 124 | 125 | cost = cost.view(size[0], size[1], -1, size[2], size[3]) 126 | 127 | return cost.contiguous() 128 | 129 | 130 | 131 | 132 | def warp(x, disp): 133 | """ 134 | warp an image/tensor (im2) back to im1, according to the optical flow 135 | x: [B, C, H, W] (im2) 136 | flo: [B, 2, H, W] flow 137 | """ 138 | B, C, H, W = x.size() 139 | # mesh grid 140 | xx = torch.arange(0, W, device='cuda').view(1, -1).repeat(H, 1) 141 | yy = torch.arange(0, H, device='cuda').view(-1, 1).repeat(1, W) 142 | xx = xx.view(1, 1, H, W).repeat(B, 1, 1, 1) 143 | yy = yy.view(1, 1, H, W).repeat(B, 1, 1, 1) 144 | vgrid = torch.cat((xx, yy), 1).float() 145 | 146 | # vgrid = Variable(grid) 147 | 
vgrid[:,:1,:,:] = vgrid[:,:1,:,:] - disp 148 | 149 | # scale grid to [-1,1] 150 | vgrid[:, 0, :, :] = 2.0 * vgrid[:, 0, :, :].clone() / max(W - 1, 1) - 1.0 151 | vgrid[:, 1, :, :] = 2.0 * vgrid[:, 1, :, :].clone() / max(H - 1, 1) - 1.0 152 | 153 | vgrid = vgrid.permute(0, 2, 3, 1) 154 | #output = nn.functional.grid_sample(x, vgrid, align_corners=True ) 155 | output = nn.functional.grid_sample(x, vgrid) 156 | return output 157 | 158 | 159 | 160 | def _build_cost_volume(cost_volume_type, refimg_fea, targetimg_fea, maxdisp): 161 | if cost_volume_type == "Concat": 162 | 163 | cost = _build_volume_2d_psmnet(refimg_fea, targetimg_fea, maxdisp=maxdisp // 8) 164 | 165 | 166 | elif cost_volume_type == "Distance_based": 167 | 168 | cost = _build_volume_2d_anynet(refimg_fea, targetimg_fea, maxdisp // 8, stride=1) 169 | cost = torch.unsqueeze(cost, 1) 170 | 171 | 172 | elif cost_volume_type == "Difference": 173 | #print("build difference") 174 | cost = _build_volume_2d_difference(refimg_fea, targetimg_fea, maxdisp // 8) 175 | #print("cost size:", cost.shape) 176 | 177 | 178 | else: 179 | AssertionError("please define cost volume types") 180 | 181 | return cost 182 | 183 | 184 | 185 | 186 | 187 | def _build_redidual_cost_volume(cost_volume_type, L2, R2, wflow, maxdisp): 188 | if cost_volume_type == "Concat": 189 | 190 | cost_residual = _build_volume_2d3_psmnet(L2, R2, wflow, maxdisp) 191 | 192 | elif cost_volume_type == "Distance_based": 193 | cost_residual = _build_volume_2d3_anynet(L2, R2, wflow, maxdisp) 194 | cost_residual = torch.unsqueeze(cost_residual, 1) 195 | 196 | 197 | elif cost_volume_type == "Difference": 198 | cost_residual = _build_volume_2d3_difference(L2, R2, wflow, maxdisp) 199 | # cost_residual = torch.unsqueeze(cost_residual, 1) 200 | 201 | else: 202 | AssertionError("please define cost volume types") 203 | 204 | return cost_residual 205 | -------------------------------------------------------------------------------- /models/cspn.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Xinjing Cheng & Peng Wang 3 | 4 | """ 5 | 6 | import torch.nn as nn 7 | import math 8 | import torch.utils.model_zoo as model_zoo 9 | import torch 10 | from torch.autograd import Variable 11 | import torch.nn.functional as F 12 | 13 | 14 | class Affinity_Propagate(nn.Module): 15 | 16 | def __init__(self, 17 | prop_time, 18 | prop_kernel, 19 | norm_type='8sum'): 20 | """ 21 | 22 | Inputs: 23 | prop_time: how many steps for CSPN to perform 24 | prop_kernel: the size of kernel (current only support 3x3) 25 | way to normalize affinity 26 | '8sum': normalize using 8 surrounding neighborhood 27 | '8sum_abs': normalization enforcing affinity to be positive 28 | This will lead the center affinity to be 0 29 | """ 30 | super(Affinity_Propagate, self).__init__() 31 | self.prop_time = prop_time 32 | self.prop_kernel = prop_kernel 33 | assert prop_kernel == 3, 'this version only support 8 (3x3 - 1) neighborhood' 34 | 35 | self.norm_type = norm_type 36 | assert norm_type in ['8sum', '8sum_abs'] 37 | 38 | self.in_feature = 1 39 | self.out_feature = 1 40 | 41 | self.sum_conv = nn.Conv3d(in_channels=8, 42 | out_channels=1, 43 | kernel_size=(1, 1, 1), 44 | stride=1, 45 | padding=0, 46 | bias=False) 47 | weight = torch.ones(1, 8, 1, 1, 1).cuda() 48 | self.sum_conv.weight = nn.Parameter(weight) 49 | for param in self.sum_conv.parameters(): 50 | param.requires_grad = False 51 | 52 | 53 | def forward(self, guidance, blur_depth): 54 | 55 | # self.sum_conv = 
nn.Conv3d(in_channels=8, 56 | # out_channels=1, 57 | # kernel_size=(1, 1, 1), 58 | # stride=1, 59 | # padding=0, 60 | # bias=False) 61 | weight = torch.ones(1, 8, 1, 1, 1).cuda() 62 | self.sum_conv.weight = nn.Parameter(weight) 63 | for param in self.sum_conv.parameters(): 64 | param.requires_grad = False 65 | 66 | gate_wb, gate_sum = self.affinity_normalization(guidance) 67 | 68 | # pad input and convert to 8 channel 3D features 69 | raw_depth_input = blur_depth 70 | 71 | #blur_depht_pad = nn.ZeroPad2d((1,1,1,1)) 72 | result_depth = blur_depth 73 | 74 | 75 | 76 | for i in range(self.prop_time): 77 | # one propagation 78 | spn_kernel = self.prop_kernel 79 | #print('11111111111111111111111') 80 | result_depth = self.pad_blur_depth(result_depth) 81 | neigbor_weighted_sum = self.sum_conv(gate_wb * result_depth) 82 | neigbor_weighted_sum = neigbor_weighted_sum.squeeze(1) 83 | neigbor_weighted_sum = neigbor_weighted_sum[:, :, 1:-1, 1:-1] 84 | result_depth = neigbor_weighted_sum 85 | 86 | if '8sum' in self.norm_type: 87 | result_depth = (1.0 - gate_sum) * raw_depth_input + result_depth 88 | else: 89 | raise ValueError('unknown norm %s' % self.norm_type) 90 | 91 | 92 | return result_depth 93 | 94 | def affinity_normalization(self, guidance): 95 | 96 | # normalize features 97 | if 'abs' in self.norm_type: 98 | guidance = torch.abs(guidance) 99 | 100 | gate1_wb_cmb = guidance.narrow(1, 0 , self.out_feature) 101 | gate2_wb_cmb = guidance.narrow(1, 1 * self.out_feature, self.out_feature) 102 | gate3_wb_cmb = guidance.narrow(1, 2 * self.out_feature, self.out_feature) 103 | gate4_wb_cmb = guidance.narrow(1, 3 * self.out_feature, self.out_feature) 104 | gate5_wb_cmb = guidance.narrow(1, 4 * self.out_feature, self.out_feature) 105 | gate6_wb_cmb = guidance.narrow(1, 5 * self.out_feature, self.out_feature) 106 | gate7_wb_cmb = guidance.narrow(1, 6 * self.out_feature, self.out_feature) 107 | gate8_wb_cmb = guidance.narrow(1, 7 * self.out_feature, self.out_feature) 108 | 109 | # gate1:left_top, gate2:center_top, gate3:right_top 110 | # gate4:left_center, , gate5: right_center 111 | # gate6:left_bottom, gate7: center_bottom, gate8: right_bottm 112 | 113 | # top pad 114 | left_top_pad = nn.ZeroPad2d((0,2,0,2)) 115 | gate1_wb_cmb = left_top_pad(gate1_wb_cmb).unsqueeze(1) 116 | 117 | center_top_pad = nn.ZeroPad2d((1,1,0,2)) 118 | gate2_wb_cmb = center_top_pad(gate2_wb_cmb).unsqueeze(1) 119 | 120 | right_top_pad = nn.ZeroPad2d((2,0,0,2)) 121 | gate3_wb_cmb = right_top_pad(gate3_wb_cmb).unsqueeze(1) 122 | 123 | # center pad 124 | left_center_pad = nn.ZeroPad2d((0,2,1,1)) 125 | gate4_wb_cmb = left_center_pad(gate4_wb_cmb).unsqueeze(1) 126 | 127 | right_center_pad = nn.ZeroPad2d((2,0,1,1)) 128 | gate5_wb_cmb = right_center_pad(gate5_wb_cmb).unsqueeze(1) 129 | 130 | # bottom pad 131 | left_bottom_pad = nn.ZeroPad2d((0,2,2,0)) 132 | gate6_wb_cmb = left_bottom_pad(gate6_wb_cmb).unsqueeze(1) 133 | 134 | center_bottom_pad = nn.ZeroPad2d((1,1,2,0)) 135 | gate7_wb_cmb = center_bottom_pad(gate7_wb_cmb).unsqueeze(1) 136 | 137 | right_bottm_pad = nn.ZeroPad2d((2,0,2,0)) 138 | gate8_wb_cmb = right_bottm_pad(gate8_wb_cmb).unsqueeze(1) 139 | 140 | gate_wb = torch.cat((gate1_wb_cmb,gate2_wb_cmb,gate3_wb_cmb,gate4_wb_cmb, 141 | gate5_wb_cmb,gate6_wb_cmb,gate7_wb_cmb,gate8_wb_cmb), 1) 142 | 143 | # normalize affinity using their abs sum 144 | gate_wb_abs = torch.abs(gate_wb) 145 | abs_weight = self.sum_conv(gate_wb_abs) 146 | 147 | gate_wb = torch.div(gate_wb, abs_weight) 148 | gate_sum = self.sum_conv(gate_wb) 149 | 150 | 
gate_sum = gate_sum.squeeze(1) 151 | gate_sum = gate_sum[:, :, 1:-1, 1:-1] 152 | 153 | return gate_wb, gate_sum 154 | 155 | 156 | def pad_blur_depth(self, blur_depth): 157 | # top pad 158 | left_top_pad = nn.ZeroPad2d((0,2,0,2)) 159 | blur_depth_1 = left_top_pad(blur_depth).unsqueeze(1) 160 | center_top_pad = nn.ZeroPad2d((1,1,0,2)) 161 | blur_depth_2 = center_top_pad(blur_depth).unsqueeze(1) 162 | right_top_pad = nn.ZeroPad2d((2,0,0,2)) 163 | blur_depth_3 = right_top_pad(blur_depth).unsqueeze(1) 164 | 165 | # center pad 166 | left_center_pad = nn.ZeroPad2d((0,2,1,1)) 167 | blur_depth_4 = left_center_pad(blur_depth).unsqueeze(1) 168 | right_center_pad = nn.ZeroPad2d((2,0,1,1)) 169 | blur_depth_5 = right_center_pad(blur_depth).unsqueeze(1) 170 | 171 | # bottom pad 172 | left_bottom_pad = nn.ZeroPad2d((0,2,2,0)) 173 | blur_depth_6 = left_bottom_pad(blur_depth).unsqueeze(1) 174 | center_bottom_pad = nn.ZeroPad2d((1,1,2,0)) 175 | blur_depth_7 = center_bottom_pad(blur_depth).unsqueeze(1) 176 | right_bottm_pad = nn.ZeroPad2d((2,0,2,0)) 177 | blur_depth_8 = right_bottm_pad(blur_depth).unsqueeze(1) 178 | 179 | result_depth = torch.cat((blur_depth_1, blur_depth_2, blur_depth_3, blur_depth_4, 180 | blur_depth_5, blur_depth_6, blur_depth_7, blur_depth_8), 1) 181 | return result_depth 182 | 183 | 184 | def normalize_gate(self, guidance): 185 | gate1_x1_g1 = guidance.narrow(1,0,1) 186 | gate1_x1_g2 = guidance.narrow(1,1,1) 187 | gate1_x1_g1_abs = torch.abs(gate1_x1_g1) 188 | gate1_x1_g2_abs = torch.abs(gate1_x1_g2) 189 | elesum_gate1_x1 = torch.add(gate1_x1_g1_abs, gate1_x1_g2_abs) 190 | gate1_x1_g1_cmb = torch.div(gate1_x1_g1, elesum_gate1_x1) 191 | gate1_x1_g2_cmb = torch.div(gate1_x1_g2, elesum_gate1_x1) 192 | return gate1_x1_g1_cmb, gate1_x1_g2_cmb 193 | 194 | 195 | def max_of_4_tensor(self, element1, element2, element3, element4): 196 | max_element1_2 = torch.max(element1, element2) 197 | max_element3_4 = torch.max(element3, element4) 198 | return torch.max(max_element1_2, max_element3_4) 199 | 200 | def max_of_8_tensor(self, element1, element2, element3, element4, element5, element6, element7, element8): 201 | max_element1_2 = self.max_of_4_tensor(element1, element2, element3, element4) 202 | max_element3_4 = self.max_of_4_tensor(element5, element6, element7, element8) 203 | return torch.max(max_element1_2, max_element3_4) 204 | 205 | 206 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.parallel 6 | import torch.optim as optim 7 | import torch.utils.data 8 | import torch.nn.functional as F 9 | import time 10 | from torch.autograd import Variable 11 | from dataloader import listflowfile as lt 12 | from dataloader import SecenFlowLoader as DA 13 | import utils.logger as logger 14 | from utils.flops_hook import profile 15 | from models.LWANet import * 16 | 17 | 18 | parser = argparse.ArgumentParser(description='LWANet with Sceneflow dataset') 19 | parser.add_argument('--maxdisp', type=int, default=192, help='maxium disparity') 20 | parser.add_argument('--loss_weights', type=float, nargs='+', default=[1., 1.]) 21 | parser.add_argument('--maxdisplist', type=int, nargs='+', default=[24, 3, 3]) 22 | parser.add_argument('--lr', type=float, default=5e-4, help='learning rate') 23 | parser.add_argument('--with_cspn', type =bool, default= True, help='with cspn network or not') 24 | 
parser.add_argument('--datapath', default='/data6/wsgan/SenceFlow/train/', help='datapath') 25 | parser.add_argument('--epochs', type=int, default=50, help='number of epochs to train') 26 | parser.add_argument('--train_bsize', type=int, default=16, help='batch size for training (default: 12)') 27 | parser.add_argument('--test_bsize', type=int, default=8, help='batch size for testing (default: 8)') 28 | parser.add_argument('--save_path', type=str, default='./results/sceneflow/', help='the path of saving checkpoints and log') 29 | parser.add_argument('--resume', type=str, default=None, help='resume path') 30 | parser.add_argument('--print_freq', type=int, default=400, help='print frequence') 31 | 32 | parser.add_argument('--model_types', type=str, default='LWANet', help='model_types : 3D, P3D') 33 | parser.add_argument('--conv_3d_types1', type=str, default='P3D', help='model_types : 3D, P3D ') 34 | parser.add_argument('--conv_3d_types2', type=str, default='P3D', help='model_types : 3D, P3D') 35 | parser.add_argument('--cost_volume', type=str, default='Difference', help='cost_volume type : "Concat" , "Difference" or "Distance_based" ') 36 | parser.add_argument('--train', type =bool, default=True, help='train or test ') 37 | 38 | 39 | args = parser.parse_args() 40 | 41 | #CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py 42 | 43 | 44 | def main(): 45 | global args 46 | 47 | train_left_img, train_right_img, train_left_disp, test_left_img, test_right_img, test_left_disp = lt.dataloader( 48 | args.datapath) 49 | TrainImgLoader = torch.utils.data.DataLoader( 50 | DA.myImageFloder(train_left_img, train_right_img, train_left_disp, True), 51 | batch_size=args.train_bsize, shuffle=True, num_workers=4, drop_last=False) 52 | TestImgLoader = torch.utils.data.DataLoader( 53 | DA.myImageFloder(test_left_img, test_right_img, test_left_disp, False), 54 | batch_size=args.test_bsize, shuffle=False, num_workers=4, drop_last=False) 55 | 56 | if not os.path.isdir(args.save_path): 57 | os.makedirs(args.save_path) 58 | log = logger.setup_logger(args.save_path + '/training.log') 59 | for key, value in sorted(vars(args).items()): 60 | log.info(str(key) + ': ' + str(value)) 61 | 62 | 63 | model = LWANet(args) 64 | 65 | 66 | # FLOPs, params = count_flops(model.cuda()) 67 | # log.info('Number of model parameters: {}'.format(params)) 68 | # log.info('Number of model FLOPs: {}'.format(FLOPs)) 69 | 70 | 71 | model = nn.DataParallel(model).cuda() 72 | optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999)) 73 | 74 | args.start_epoch = 0 75 | if args.resume: 76 | if os.path.isfile(args.resume): 77 | log.info("=> loading checkpoint '{}'".format(args.resume)) 78 | checkpoint = torch.load(args.resume) 79 | args.start_epoch = checkpoint['epoch'] 80 | model.load_state_dict(checkpoint['state_dict']) 81 | optimizer.load_state_dict(checkpoint['optimizer']) 82 | log.info("=> loaded checkpoint '{}' (epoch {})" 83 | .format(args.resume, checkpoint['epoch'])) 84 | else: 85 | log.info("=> no checkpoint found at '{}'".format(args.resume)) 86 | log.info("=> Will start from scratch.") 87 | else: 88 | log.info('Not Resume') 89 | 90 | start_full_time = time.time() 91 | 92 | if args.train: 93 | for epoch in range(args.start_epoch, args.epochs): 94 | log.info('This is {}-th epoch'.format(epoch)) 95 | 96 | train(TrainImgLoader, model, optimizer, log, epoch) 97 | 98 | savefilename = args.save_path + '/checkpoint_' + str(epoch) + '.tar' 99 | 100 | torch.save({ 101 | 'epoch': epoch, 102 | 'state_dict': model.state_dict(), 103 | 
'optimizer': optimizer.state_dict(), 104 | }, savefilename) 105 | 106 | if not epoch % 10: 107 | test(TestImgLoader, model, log) 108 | 109 | test(TestImgLoader, model, log) 110 | log.info('full training time = {:.2f} Hours'.format((time.time() - start_full_time) / 3600)) 111 | 112 | 113 | def train(dataloader, model, optimizer, log, epoch=0): 114 | 115 | 116 | stages = 2 117 | losses = [AverageMeter() for _ in range(stages)] 118 | length_loader = len(dataloader) 119 | 120 | model.train() 121 | 122 | for batch_idx, (imgL, imgR, disp_L) in enumerate(dataloader): 123 | imgL = imgL.float().cuda() 124 | imgR = imgR.float().cuda() 125 | disp_L = disp_L.float().cuda() 126 | 127 | optimizer.zero_grad() 128 | mask = (disp_L < args.maxdisp) & (disp_L > 0) 129 | if mask.float().sum() == 0: 130 | continue 131 | 132 | mask.detach_() 133 | 134 | outputs, self_supervised_loss = model(imgL, imgR) 135 | stages = len(outputs) 136 | 137 | outputs = [torch.squeeze(output, 1) for output in outputs] 138 | 139 | loss = [args.loss_weights[x] * F.smooth_l1_loss(outputs[x][mask], disp_L[mask], size_average=True) 140 | for x in range(stages)] 141 | 142 | sum(loss).backward() 143 | optimizer.step() 144 | 145 | for idx in range(stages): 146 | losses[idx].update(loss[idx].item()/args.loss_weights[idx]) 147 | 148 | if batch_idx % args.print_freq ==0: 149 | info_str = ['Stage {} = {:.2f}({:.2f})'.format(x, losses[x].val, losses[x].avg) for x in range(stages)] 150 | info_str = '\t'.join(info_str) 151 | 152 | log.info('Epoch{} [{}/{}] {}'.format( 153 | epoch, batch_idx, length_loader, info_str)) 154 | info_str = '\t'.join(['Stage {} = {:.2f}'.format(x, losses[x].avg) for x in range(stages)]) 155 | log.info('Average train loss = ' + info_str) 156 | 157 | 158 | 159 | def test(dataloader, model, log): 160 | 161 | stages = 2 162 | EPEs = [AverageMeter() for _ in range(stages)] 163 | length_loader = len(dataloader) 164 | 165 | model.eval() 166 | 167 | inference_time = 0 168 | for batch_idx, (imgL, imgR, disp_L) in enumerate(dataloader): 169 | imgL = imgL.float().cuda() 170 | imgR = imgR.float().cuda() 171 | disp_L = disp_L.float().cuda() 172 | 173 | mask = disp_L < args.maxdisp 174 | with torch.no_grad(): 175 | 176 | time_start = time.perf_counter() 177 | 178 | 179 | outputs, monoloss = model(imgL, imgR) 180 | 181 | single_inference_time = time.perf_counter() - time_start 182 | 183 | inference_time += single_inference_time 184 | 185 | 186 | stages = len(outputs) 187 | for x in range(stages): 188 | if len(disp_L[mask]) == 0: 189 | EPEs[x].update(0) 190 | continue 191 | output = torch.squeeze(outputs[x], 1) 192 | output = output[:, 4:, :] 193 | EPEs[x].update((output[mask] - disp_L[mask]).abs().mean()) 194 | 195 | if batch_idx % args.print_freq == 0: 196 | info_str = '\t'.join(['Stage {} = {:.2f}({:.2f})'.format(x, EPEs[x].val, EPEs[x].avg) for x in range(stages)]) 197 | 198 | log.info('[{}/{}] {}'.format( 199 | batch_idx, length_loader, info_str)) 200 | 201 | log.info(('=> Mean inference time for %d images: %.3fs' % ( 202 | length_loader, inference_time / length_loader))) 203 | 204 | info_str = ', '.join(['Stage {}={:.2f}'.format(x, EPEs[x].avg) for x in range(stages)]) 205 | log.info('Average test EPE = ' + info_str) 206 | 207 | 208 | def adjust_learning_rate(optimizer, epoch): 209 | if epoch <= 20: 210 | lr = args.lr 211 | 212 | elif 20 loaded pretrained model '{}'" 99 | .format(args.pretrained)) 100 | else: 101 | log.info("=> no pretrained model found at '{}'".format(args.pretrained)) 102 | log.info("=> Will start from 
scratch.") 103 | args.start_epoch = 0 104 | 105 | 106 | cudnn.benchmark = True 107 | 108 | if args.adaptation_type == "self_supervise": 109 | model.train() 110 | loss_file = open(args.save_path + '/self_supervise' + '.txt', 'w') 111 | 112 | 113 | elif args.adaptation_type == "GT_supervise": 114 | model.train() 115 | loss_file = open(args.save_path + '/GT_supervise' + '.txt', 'w') 116 | 117 | 118 | elif args.adaptation_type == "no_supervise": 119 | 120 | loss_file = open(args.save_path + '/no_supervise' + '.txt', 'w') 121 | 122 | 123 | train(TrainImgLoader, model, optimizer, log, loss_file, args) 124 | 125 | 126 | 127 | def train(dataloader, model, optimizer, log, loss_file, args): 128 | 129 | losses = [AverageMeter() for _ in range(2)] 130 | length_loader = len(dataloader) 131 | D1s = [AverageMeter() for _ in range(2)] 132 | 133 | start_full_time = time.time() 134 | for batch_idx, (imgL, imgR, disp_L) in enumerate(dataloader): 135 | imgL = imgL.float().cuda() 136 | imgR = imgR.float().cuda() 137 | disp_L = disp_L.float().cuda() 138 | #print('train imgR size:', imgR.shape) 139 | 140 | optimizer.zero_grad() 141 | mask = disp_L > 0 142 | mask = mask*(disp_L<192) 143 | mask.detach_() 144 | 145 | single_update_time=time.time() 146 | 147 | #outputs = model(imgL, imgR) 148 | if args.adaptation_type == "no_supervise": 149 | model.eval() 150 | with torch.no_grad(): 151 | pred, mono_loss = model(imgL, imgR) 152 | 153 | outputs = [torch.squeeze(output, 1) for output in pred] 154 | 155 | num_out = len(pred) 156 | loss = [args.loss_weights[x] * F.smooth_l1_loss(outputs[x][mask], disp_L[mask], size_average=True) 157 | for x in range(num_out)] 158 | 159 | 160 | num_out = len(pred) 161 | 162 | 163 | elif args.adaptation_type == "self_supervise": 164 | model.train() 165 | 166 | pred, mono_loss = model(imgL, imgR) 167 | outputs = [torch.squeeze(output, 1) for output in pred] 168 | num_out = len(pred) 169 | loss = [args.loss_weights[x] * F.smooth_l1_loss(outputs[x][mask], disp_L[mask], size_average=True) 170 | for x in range(num_out)] 171 | 172 | sum(mono_loss).backward() 173 | 174 | optimizer.step() 175 | 176 | elif args.adaptation_type == "GT_supervise": 177 | model.train() 178 | 179 | pred, mono_loss = model(imgL, imgR) 180 | 181 | outputs = [torch.squeeze(output, 1) for output in pred] 182 | 183 | num_out = len(pred) 184 | loss = [args.loss_weights[x] * F.smooth_l1_loss(outputs[x][mask], disp_L[mask], size_average=True) 185 | for x in range(num_out)] 186 | 187 | sum(loss).backward() 188 | optimizer.step() 189 | 190 | 191 | 192 | print('sigle_update_time: {:.4f} seconds'.format(time.time() - single_update_time)) 193 | # image out and error estimation 194 | 195 | # three pixel error 196 | 197 | output = torch.squeeze(pred[1], 1) 198 | D1s[1].update(error_estimating(output, disp_L).item()) 199 | print('output size:', output.shape) 200 | 201 | 202 | 203 | # save the adaptation disparity 204 | if args.save_disparity : 205 | 206 | plt.imshow(output.squeeze(0).cpu().detach().numpy()) 207 | plt.axis('off') 208 | 209 | plt.gcf().set_size_inches(1216 / 100, 320 / 100) 210 | plt.gca().xaxis.set_major_locator(plt.NullLocator()) 211 | plt.gca().yaxis.set_major_locator(plt.NullLocator()) 212 | plt.subplots_adjust(top=1, bottom=0, left=0, right=1, hspace=0, wspace=0) 213 | plt.margins(0, 0) 214 | 215 | plt.savefig(args.save_path+'/disparity/{}.png'.format(batch_idx)) 216 | 217 | # if args.save_disparity: 218 | # 219 | # imgL = imgL.squeeze(0).permute(1,2,0) 220 | # #print("imgL size:", imgL.shape) 221 | # 
plt.imshow(imgL.cpu().detach().numpy()) 222 | # plt.axis('off') 223 | # 224 | # plt.gcf().set_size_inches(1216 / 100, 320 / 100) 225 | # plt.gca().xaxis.set_major_locator(plt.NullLocator()) 226 | # plt.gca().yaxis.set_major_locator(plt.NullLocator()) 227 | # plt.subplots_adjust(top=1, bottom=0, left=0, right=1, hspace=0, wspace=0) 228 | # plt.margins(0, 0) 229 | # 230 | # plt.savefig(args.save_path + '/disparity/{}.png'.format(batch_idx)) 231 | # 232 | 233 | 234 | 235 | loss_file.write('{:.4f}\n'.format(D1s[1].val)) 236 | 237 | for idx in range(num_out): 238 | losses[idx].update(loss[idx].item()) 239 | 240 | 241 | info_str = ['Stage {} = {:.2f}({:.2f})'.format(x, losses[x].val, losses[x].avg) for x in range(num_out)] 242 | info_str = '\t'.join(info_str) 243 | 244 | log.info('Epoch{} [{}/{}] {}'.format( 1, batch_idx, length_loader, info_str)) 245 | 246 | end_time = time.time() 247 | 248 | log.info('full training time = {:.2f} Hours, full train time = {:.4f} seconds'.format( 249 | (end_time - start_full_time) / 3600, end_time - start_full_time)) 250 | 251 | # summary 252 | info_str = ', '.join(['Stage {}={:.4f}'.format(x, D1s[x].avg) for x in range(num_out)]) 253 | 254 | log.info('Average test 3-Pixel Error = ' + info_str) 255 | 256 | info_str = '\t'.join(['Stage {} = {:.2f}'.format(x, losses[x].avg) for x in range(num_out)]) 257 | log.info('Average train loss = ' + info_str) 258 | 259 | loss_file.close() 260 | 261 | 262 | 263 | def error_estimating(disp, ground_truth, maxdisp=192): 264 | 265 | gt = ground_truth 266 | mask = gt > 0 267 | mask = mask * (gt < maxdisp) 268 | 269 | errmap = torch.abs(disp - gt) 270 | err3 = ((errmap[mask] > 3.) & (errmap[mask] / gt[mask] > 0.05)).sum() 271 | return err3.float() / mask.sum().float() 272 | 273 | 274 | 275 | class AverageMeter(object): 276 | """Computes and stores the average and current value""" 277 | 278 | def __init__(self): 279 | self.reset() 280 | 281 | def reset(self): 282 | self.val = 0 283 | self.avg = 0 284 | self.sum = 0 285 | self.count = 0 286 | 287 | def update(self, val, n=1): 288 | self.val = val 289 | self.sum += val * n 290 | self.count += n 291 | self.avg = self.sum / self.count 292 | 293 | 294 | if __name__ == '__main__': 295 | main() 296 | 297 | 298 | 299 | -------------------------------------------------------------------------------- /finetune.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.parallel 6 | import torch.optim as optim 7 | import torch.utils.data 8 | import torch.nn.functional as F 9 | import time 10 | from dataloader import KITTILoader as DA 11 | from dataloader import KITTIdatalist as ls 12 | import utils.logger as logger 13 | import torch.backends.cudnn as cudnn 14 | from models.LWANet import * 15 | 16 | import pdb 17 | 18 | # 查看GPU使用情況 19 | # watch --color -n1 gpustat -cpu 20 | 21 | 22 | parser = argparse.ArgumentParser(description='LWANet fintune on KITTI') 23 | parser.add_argument('--maxdisp', type=int, default=192, help='maxium disparity') 24 | parser.add_argument('--loss_weights', type=float, nargs='+', default=[1., 1.]) 25 | parser.add_argument('--max_disparity', type=int, default=192) 26 | parser.add_argument('--maxdisplist', type=int, nargs='+', default=[24, 3, 3]) 27 | parser.add_argument('--with_cspn', type =bool, default= True, help='with cspn network or not') 28 | parser.add_argument('--cost_volume', type=str, default='Difference', help='cost_volume type : "Concat" , 
"Difference" or "Distance_based"') 29 | parser.add_argument('--lr', type=float, default=5e-4*0.5, help='learning rate') 30 | parser.add_argument('--epochs', type=int, default=1001, help='number of epochs to train') 31 | parser.add_argument('--train_bsize', type=int, default=8, help='batch size for training (default: 8)') 32 | parser.add_argument('--test_bsize', type=int, default=8,help='batch size for testing (default: 8)') 33 | parser.add_argument('--resume', type=str, default= None, help='resume path') 34 | parser.add_argument('--print_freq', type=int, default=10, help='print frequence') 35 | parser.add_argument('--pretrained', type=str, default=None, help='pretrained model path') 36 | parser.add_argument('--model_types', type=str, default='LWANet', help='model_types : 3D OR P3D') 37 | parser.add_argument('--conv_3d_types1', type=str, default='P3D', help='model_types : 3D, P3D ') 38 | parser.add_argument('--conv_3d_types2', type=str, default='P3D', help='model_types : 3D, P3D') 39 | 40 | 41 | parser.add_argument('--save_path', type=str, default='/results/finetune2015/',help='the path of saving checkpoints and log') 42 | parser.add_argument('--split_for_val', type =bool, default=False, help='finetune for submission or for validation') 43 | parser.add_argument('--datatype', default='mix', help='finetune dataset: 2012, 2015, mix') 44 | parser.add_argument('--datapath2015', default='/data6/wsgan/KITTI/KITTI2015/training/', help='datapath') 45 | parser.add_argument('--datapath2012', default='/data6/wsgan/KITTI/KITTI2012/training/', help='datapath') 46 | 47 | 48 | args = parser.parse_args() 49 | 50 | 51 | #CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python finetune.py 52 | 53 | def main(): 54 | global args 55 | log = logger.setup_logger(args.save_path + '/training.log') 56 | 57 | if args.datatype == '2015': 58 | 59 | all_left_img, all_right_img, all_left_disp, test_left_img, test_right_img, test_left_disp = ls.dataloader2015( 60 | args.datapath2015, log, split=args.split_for_val) 61 | 62 | elif args.datatype == '2012': 63 | 64 | all_left_img, all_right_img, all_left_disp, test_left_img, test_right_img, test_left_disp = ls.dataloader2012( 65 | args.datapath2012, log, split = False) 66 | 67 | elif args.datatype == 'mix': 68 | 69 | all_left_img_2015, all_right_img_2015, all_left_disp_2015, test_left_img_2015, test_right_img_2015, test_left_disp_2015 = ls.dataloader2015( 70 | args.datapath2015, log, split=False) 71 | all_left_img_2012, all_right_img_2012, all_left_disp_2012, test_left_img_2012, test_right_img_2012, test_left_disp_2012 = ls.dataloader2012( 72 | args.datapath2012, log, split=False) 73 | all_left_img, all_right_img, all_left_disp, test_left_img, test_right_img, test_left_disp = \ 74 | all_left_img_2015 + all_left_img_2012, all_right_img_2015 + all_right_img_2012, \ 75 | all_left_disp_2015 + all_left_disp_2012, test_left_img_2015 + test_left_img_2012, \ 76 | test_right_img_2015 + test_right_img_2012, test_left_disp_2015 + test_left_disp_2012 77 | else: 78 | 79 | AssertionError("please define the finetune dataset") 80 | 81 | TrainImgLoader = torch.utils.data.DataLoader( 82 | DA.myImageFloder(all_left_img, all_right_img, all_left_disp, True), 83 | batch_size=args.train_bsize, shuffle=True, num_workers=4, drop_last=False) 84 | 85 | TestImgLoader = torch.utils.data.DataLoader( 86 | DA.myImageFloder(test_left_img, test_right_img, test_left_disp, False), 87 | batch_size=args.test_bsize, shuffle=False, num_workers=4, drop_last=False) 88 | 89 | if not os.path.isdir(args.save_path): 90 | 
os.makedirs(args.save_path) 91 | for key, value in sorted(vars(args).items()): 92 | log.info(str(key) + ': ' + str(value)) 93 | 94 | 95 | model = LWANet(args) 96 | 97 | 98 | model = nn.DataParallel(model).cuda() 99 | optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999)) 100 | log.info('Number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()]))) 101 | 102 | if args.pretrained: 103 | if os.path.isfile(args.pretrained): 104 | checkpoint = torch.load(args.pretrained) 105 | model.load_state_dict(checkpoint['state_dict'], strict=False) 106 | log.info("=> loaded pretrained model '{}'" 107 | .format(args.pretrained)) 108 | else: 109 | log.info("=> no pretrained model found at '{}'".format(args.pretrained)) 110 | log.info("=> Will start from scratch.") 111 | args.start_epoch = 0 112 | if args.resume: 113 | if os.path.isfile(args.resume): 114 | log.info("=> loading checkpoint '{}'".format(args.resume)) 115 | checkpoint = torch.load(args.resume) 116 | model.load_state_dict(checkpoint['state_dict'], strict=False) 117 | optimizer.load_state_dict(checkpoint['optimizer']) 118 | args.start_epoch = checkpoint['epoch'] + 1 119 | log.info("=> loaded checkpoint '{}' (epoch {})" 120 | .format(args.resume, checkpoint['epoch'])) 121 | else: 122 | log.info("=> no checkpoint found at '{}'".format(args.resume)) 123 | log.info("=> Will start from scratch.") 124 | else: 125 | log.info('Not Resume') 126 | cudnn.benchmark = True 127 | 128 | start_full_time = time.time() 129 | 130 | 131 | 132 | for epoch in range(args.start_epoch, args.epochs): 133 | log.info('This is {}-th epoch'.format(epoch)) 134 | adjust_learning_rate(optimizer, epoch) 135 | 136 | train(TrainImgLoader, model, optimizer, log, epoch) 137 | 138 | if epoch % 100 == 0: 139 | savefilename = args.save_path + '/finetune_' + str(epoch) + '.tar' 140 | torch.save({ 141 | 'epoch': epoch, 142 | 'state_dict': model.state_dict(), 143 | 'optimizer': optimizer.state_dict(), 144 | }, savefilename) 145 | 146 | 147 | 148 | if epoch % 20 == 0: 149 | test(TestImgLoader, model, log) 150 | 151 | 152 | 153 | test(TestImgLoader, model, log) 154 | log.info('full training time = {:.2f} Hours'.format((time.time() - start_full_time) / 3600)) 155 | 156 | 157 | 158 | def train(dataloader, model, optimizer, log, epoch=0): 159 | 160 | stages = 2 161 | losses = [AverageMeter() for _ in range(stages)] 162 | length_loader = len(dataloader) 163 | 164 | model.train() 165 | 166 | for batch_idx, (imgL, imgR, disp_L) in enumerate(dataloader): 167 | 168 | imgL = imgL.float().cuda() 169 | imgR = imgR.float().cuda() 170 | disp_L = disp_L.float().cuda() 171 | 172 | optimizer.zero_grad() 173 | mask = (disp_L > 0) & (disp_L < args.maxdisp) 174 | mask.detach_() 175 | 176 | pred, mono_loss = model(imgL, imgR) 177 | 178 | outputs = [torch.squeeze(output, 1) for output in pred] 179 | 180 | num_out = len(pred) 181 | loss = [args.loss_weights[x] * F.smooth_l1_loss(outputs[x][mask], disp_L[mask], size_average=True) 182 | for x in range(num_out)] 183 | 184 | sum(loss).backward() 185 | 186 | optimizer.step() 187 | 188 | for idx in range(num_out): 189 | losses[idx].update(loss[idx].item()) 190 | 191 | if batch_idx % args.print_freq == 0: 192 | info_str = ['Stage {} = {:.2f}({:.2f})'.format(x, losses[x].val, losses[x].avg) for x in range(num_out)] 193 | info_str = '\t'.join(info_str) 194 | 195 | log.info('Epoch{} [{}/{}] {}'.format( 196 | epoch, batch_idx, length_loader, info_str)) 197 | 198 | info_str = '\t'.join(['Stage {} = {:.2f}'.format(x, 
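# The summary below only reports the stage-0 average because the comprehension
# iterates range(1), even though both stages contribute to the loss above; the
# corresponding summary in main.py iterates over all stages.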
losses[x].avg) for x in range(1)]) 199 | log.info('Average train loss = ' + info_str) 200 | 201 | 202 | def test(dataloader, model, log): 203 | 204 | stages = 3 + args.with_cspn 205 | D1s = [AverageMeter() for _ in range(stages)] 206 | length_loader = len(dataloader) 207 | 208 | model.eval() 209 | 210 | total_inference_time = 0 211 | for batch_idx, (imgL, imgR, disp_L) in enumerate(dataloader): 212 | 213 | imgL = imgL.float().cuda() 214 | imgR = imgR.float().cuda() 215 | disp_L = disp_L.float().cuda() 216 | 217 | with torch.no_grad(): 218 | 219 | start_time = time.time() 220 | outputs, mono_loss = model(imgL, imgR) 221 | print(time.time() - start_time) 222 | total_inference_time += time.time() - start_time 223 | 224 | num_out = len(outputs) 225 | for x in range(num_out): 226 | 227 | output = torch.squeeze(outputs[x], 1) 228 | D1s[x].update(error_estimating(output, disp_L).item()) 229 | 230 | info_str = '\t'.join(['Stage {} = {:.4f}({:.4f})'.format(x, D1s[x].val, D1s[x].avg) for x in range(num_out)]) 231 | 232 | log.info('[{}/{}] {}'.format( batch_idx, length_loader, info_str)) 233 | 234 | log.info("mean inference time: %.3fs " % (total_inference_time / length_loader)) 235 | info_str = ', '.join(['Stage {}={:.4f}'.format(x, D1s[x].avg) for x in range(num_out)]) 236 | log.info('Average test 3-Pixel Error = ' + info_str) 237 | 238 | 239 | def error_estimating(disp, ground_truth, maxdisp=192): 240 | 241 | gt = ground_truth 242 | mask = gt > 0 243 | mask = mask * (gt < maxdisp) 244 | errmap = torch.abs(disp - gt) 245 | err3 = ((errmap[mask] > 3.) & (errmap[mask] / gt[mask] > 0.05)).sum() 246 | 247 | return err3.float() / mask.sum().float() 248 | 249 | 250 | def adjust_learning_rate(optimizer, epoch): 251 | if epoch <= 600: 252 | lr = args.lr 253 | 254 | elif 600< epoch <= 1000: 255 | lr = args.lr*0.1 256 | 257 | else: 258 | lr = args.lr*0.01 259 | 260 | for param_group in optimizer.param_groups: 261 | param_group['lr'] = lr 262 | 263 | 264 | 265 | class AverageMeter(object): 266 | """Computes and stores the average and current value""" 267 | 268 | def __init__(self): 269 | self.reset() 270 | 271 | def reset(self): 272 | self.val = 0 273 | self.avg = 0 274 | self.sum = 0 275 | self.count = 0 276 | 277 | def update(self, val, n=1): 278 | self.val = val 279 | self.sum += val * n 280 | self.count += n 281 | self.avg = self.sum / self.count 282 | 283 | 284 | 285 | if __name__ == '__main__': 286 | main() 287 | 288 | 289 | 290 | -------------------------------------------------------------------------------- /One_cycle.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch.nn.parallel 4 | import torch.optim as optim 5 | import torch.utils.data 6 | import torch.nn.functional as F 7 | import time 8 | import matplotlib.pyplot as plt 9 | from dataloader import KITTILoader_One_cycle as DA 10 | import utils.logger as logger 11 | import torch.backends.cudnn as cudnn 12 | 13 | import models 14 | 15 | 16 | from models.LWADNet import * 17 | 18 | 19 | parser = argparse.ArgumentParser(description='Anynet fintune on KITTI') 20 | parser.add_argument('--maxdisp', type=int, default=192, 21 | help='maxium disparity') 22 | parser.add_argument('--loss_weights', type=float, nargs='+', default=[1., 1., 1., 1.]) 23 | parser.add_argument('--max_disparity', type=int, default=192) 24 | parser.add_argument('--maxdisplist', type=int, nargs='+', default=[24, 3, 3]) 25 | parser.add_argument('--datatype', default='2015', 26 | help='datapath') 27 | 
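# error_estimating() in finetune.py above (and the copy near the end of this
# file) implements the KITTI 3-pixel / 5% rule: a pixel counts as erroneous
# only if |pred - gt| > 3 px and the error also exceeds 5% of the ground-truth
# disparity, averaged over valid pixels with 0 < gt < maxdisp. For example,
# with gt = 80 a 4 px error (exactly 5%) is not counted, while with gt = 30
# the same 4 px error (> 3 px and > 1.5 px) is.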
parser.add_argument('--datapath', default='/home/wsgan/KITTI_DATASET/KITTI2015/training/', help='datapath') 28 | #parser.add_argument('--datapath', default='/home/um/GAN/Anynet/kitti2012/training/', help='datapath') 29 | 30 | parser.add_argument('--epochs', type=int, default=200, 31 | help='number of epochs to train') 32 | parser.add_argument('--train_bsize', type=int, default=1, 33 | help='batch size for training (default: 6)') 34 | parser.add_argument('--test_bsize', type=int, default=1, 35 | help='batch size for testing (default: 8)') 36 | 37 | 38 | parser.add_argument('--lr', type=float, default=5e-4, 39 | help='learning rate') 40 | parser.add_argument('--with_cspn', action='store_true', help='with spn network or not') 41 | 42 | 43 | parser.add_argument('--model_types', type=str, default='original', help='model_types : LWANet_3D, mix, original') 44 | parser.add_argument('--conv_3d_types1', type=str, default='separate_only', help='model_types : normal, P3D, separate_only, ONLY_2D ') 45 | parser.add_argument('--conv_3d_types2', type=str, default='separate_only', help='model_types : normal, P3D, separate_only, ONLY_2D') 46 | parser.add_argument('--cost_volume', type=str, default='Difference', help='cost_volume type : "Concat" , "Difference" or "Distance_based" ') 47 | 48 | 49 | parser.add_argument('--adaptation_type', default='GT_supervise', help='adaptation_type : self_supervise, GT_supervise, no_supervise') 50 | 51 | parser.add_argument('--pretrained', type=str, default='/home/wsgan/LWANet/results/pretrain/original_Difference/separate_only/checkpoint_49.tar', 52 | help='pretrained model path') 53 | parser.add_argument('--save_path', type=str, default='./results/finetune_One_cycle/GT_supervise/', 54 | help='the path of saving checkpoints and log') 55 | 56 | 57 | 58 | args = parser.parse_args() 59 | 60 | if args.datatype == '2015': 61 | from dataloader import KITTIloader2015_One_cycle as ls 62 | 63 | elif args.datatype == '2012': 64 | from dataloader import KITTIloader2012 as ls 65 | 66 | 67 | 68 | # python One_cycle.py --with_cspn 69 | 70 | def main(): 71 | global args 72 | log = logger.setup_logger(args.save_path + '/training.log') 73 | #log1 = logger.setup_logger(args.save_path + '/self_adaptive_loss.log') 74 | 75 | train_left_img, train_right_img, train_left_disp, test_left_img, test_right_img, test_left_disp = ls.dataloader( 76 | args.datapath,log) 77 | 78 | TrainImgLoader = torch.utils.data.DataLoader( 79 | DA.myImageFloder(train_left_img, train_right_img, train_left_disp, True), 80 | batch_size=args.train_bsize, shuffle=True, num_workers=4, drop_last=False) 81 | 82 | TestImgLoader = torch.utils.data.DataLoader( 83 | DA.myImageFloder(test_left_img, test_right_img, test_left_disp, False), 84 | batch_size=args.test_bsize, shuffle=False, num_workers=4, drop_last=False) 85 | 86 | if not os.path.isdir(args.save_path): 87 | os.makedirs(args.save_path) 88 | 89 | if not os.path.isdir(args.save_path+'/image'): 90 | os.makedirs(args.save_path+'/image') 91 | 92 | for key, value in sorted(vars(args).items()): 93 | log.info(str(key) + ': ' + str(value)) 94 | 95 | 96 | model = models.LWADNet.AnyNet(args) 97 | 98 | 99 | model = nn.DataParallel(model).cuda() 100 | optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999)) 101 | log.info('Number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()]))) 102 | 103 | if args.pretrained: 104 | if os.path.isfile(args.pretrained): 105 | checkpoint = torch.load(args.pretrained) 106 | 
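# strict=False on the next line lets a checkpoint with missing or unexpected
# keys load anyway (the mismatching entries are simply skipped). torch.load
# restores tensors to the device they were saved from, so loading a checkpoint
# trained on another GPU may additionally need map_location, e.g.
# torch.load(args.pretrained, map_location='cpu').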
model.load_state_dict(checkpoint['state_dict'], strict=False) 107 | log.info("=> loaded pretrained model '{}'" 108 | .format(args.pretrained)) 109 | else: 110 | log.info("=> no pretrained model found at '{}'".format(args.pretrained)) 111 | log.info("=> Will start from scratch.") 112 | args.start_epoch = 0 113 | 114 | cudnn.benchmark = True 115 | 116 | start_full_time = time.time() 117 | loss_file = open(args.save_path + '/self_supervise' + '.txt', 'w') 118 | 119 | for epoch in range(args.start_epoch, args.epochs): 120 | log.info('This is {}-th epoch'.format(epoch)) 121 | 122 | D1s= train(TrainImgLoader, model, optimizer, log, epoch) 123 | loss_file.write('{:.4f}\n'.format(D1s)) 124 | 125 | 126 | loss_file.close() 127 | 128 | log.info('full training time = {:.2f} Hours'.format((time.time() - start_full_time) / 3600)) 129 | 130 | 131 | def train(dataloader, model, optimizer, log, epoch=0): 132 | 133 | 134 | 135 | stages = 3 + args.with_cspn 136 | losses = [AverageMeter() for _ in range(stages)] 137 | length_loader = len(dataloader) 138 | D1s = [AverageMeter() for _ in range(2)] 139 | 140 | 141 | model.train() 142 | 143 | #loss_file = open(args.save_path + '/self_adaptive_loss' + '.txt', 'w') 144 | 145 | for batch_idx, (imgL, imgR, disp_L) in enumerate(dataloader): 146 | imgL = imgL.float().cuda() 147 | imgR = imgR.float().cuda() 148 | disp_L = disp_L.float().cuda() 149 | #print(' disp_L size:', disp_L) 150 | 151 | 152 | 153 | optimizer.zero_grad() 154 | mask = disp_L > 0 155 | mask.detach_() 156 | 157 | #outputs = model(imgL, imgR) 158 | pred, mono_loss = model(imgL, imgR) 159 | 160 | for x in range(len(pred)): 161 | output = torch.squeeze(pred[x], 1) 162 | D1s[x].update(error_estimating(output, disp_L).item()) 163 | 164 | # loss_file.write('{:.4f}\n'.format(D1s[1].val)) 165 | # loss_file.close() 166 | 167 | # print('len(outputs)', len(outputs)) 168 | pred = [pred for pred in pred] 169 | num_out = len(pred) 170 | #print('num_out:', num_out) 171 | 172 | 173 | outputs = [torch.squeeze(output, 1) for output in pred] 174 | 175 | output_save = outputs[1].squeeze(0) 176 | #print('output_save:', output_save.shape) 177 | 178 | #io.imsave(args.save_path + '/epoch {}.png'.format(epoch), (output_save.cpu().data.numpy() )) 179 | 180 | plt.imshow(output_save.detach().cpu().numpy()) 181 | plt.axis('off') 182 | 183 | #plt.savefig(args.save_path+'/image'+ '/epoch {} D1 {:.4f}.png'.format(epoch, D1s[1].val)) 184 | plt.savefig(args.save_path + '/image' + '/epoch {} D1 {:.4f}.png'.format(epoch, D1s[1].val), bbox_inches = 'tight', dpi= 300, pad_inches = 0) 185 | 186 | 187 | loss = [args.loss_weights[x] * F.smooth_l1_loss(outputs[x][mask], disp_L[mask], size_average=True) 188 | for x in range(num_out)] 189 | 190 | #if args.adaptation_type == "no_supervise": 191 | 192 | #sum(mono_loss).backward() 193 | sum(loss).backward() 194 | # 195 | optimizer.step() 196 | 197 | for idx in range(num_out): 198 | losses[idx].update(loss[idx].item()) 199 | 200 | if 1: 201 | info_str = ['Stage {} = {:.2f}({:.2f})'.format(x, losses[x].val, losses[x].avg) for x in range(num_out)] 202 | info_str = '\t'.join(info_str) 203 | 204 | log.info('Epoch{} [{}/{}] {}'.format( 205 | epoch, batch_idx, length_loader, info_str)) 206 | 207 | info_str = '\t'.join( 208 | ['Stage {} = {:.4f}({:.4f})'.format(x, D1s[x].val, D1s[x].avg) for x in range(num_out)]) 209 | 210 | log.info('[{}/{}] {}'.format( 211 | batch_idx, length_loader, info_str)) 212 | 213 | return D1s[1].val 214 | 215 | 216 | # info_str = '\t'.join(['Stage {} = {:.2f}'.format(x, 
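# Because train() returns D1s[1].val from inside the batch loop, every epoch
# performs exactly one optimisation step on the first pair delivered by the
# loader; that single update per epoch appears to be the "one cycle"
# adaptation this script is named after. The line `pred = [pred for pred in
# pred]` above is a no-op that merely rebinds the name to an identical list.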
losses[x].avg) for x in range(stages)]) 217 | # info_str = '\t'.join(['Stage {} = {:.2f}'.format(x, losses[x].avg) for x in range(2)]) 218 | # log.info('Average train loss = ' + info_str) 219 | 220 | 221 | def test(dataloader, model, log): 222 | 223 | stages = 3 + args.with_cspn 224 | D1s = [AverageMeter() for _ in range(stages)] 225 | length_loader = len(dataloader) 226 | 227 | model.eval() 228 | 229 | for batch_idx, (imgL, imgR, disp_L) in enumerate(dataloader): 230 | imgL = imgL.float().cuda() 231 | imgR = imgR.float().cuda() 232 | disp_L = disp_L.float().cuda() 233 | # print('test imgR size:', imgR.shape) 234 | 235 | # imgL = F.pad(imgL, [3, 3, 1, 0]) 236 | # imgR = F.pad(imgR, [3, 3, 1, 0]) 237 | # disp_L = F.pad(disp_L, [3, 3, 1, 0]) 238 | #print('imgR size:', imgR.shape) 239 | 240 | with torch.no_grad(): 241 | outputs, mono_loss = model(imgL, imgR, train = 0) 242 | 243 | 244 | # for x in range(stages): 245 | if args.with_cspn: 246 | # if epoch >= args.start_epoch_for_spn: 247 | # num_out = len(outputs) 248 | # else: 249 | # num_out = len(outputs) - 1 250 | num_out = len(outputs) 251 | 252 | else: 253 | num_out = len(outputs) 254 | 255 | for x in range(num_out): 256 | output = torch.squeeze(outputs[x], 1) 257 | 258 | # print('output size:', output.shape) 259 | # print('disp_L size:', disp_L.shape) 260 | D1s[x].update(error_estimating(output, disp_L).item()) 261 | 262 | 263 | info_str = '\t'.join(['Stage {} = {:.4f}({:.4f})'.format(x, D1s[x].val, D1s[x].avg) for x in range(num_out)]) 264 | 265 | 266 | log.info('[{}/{}] {}'.format( 267 | batch_idx, length_loader, info_str)) 268 | 269 | 270 | info_str = ', '.join(['Stage {}={:.4f}'.format(x, D1s[x].avg) for x in range(num_out)]) 271 | 272 | log.info('Average test 3-Pixel Error = ' + info_str) 273 | 274 | 275 | def error_estimating(disp, ground_truth, maxdisp=192): 276 | gt = ground_truth 277 | 278 | 279 | # gt = gt[:, 0:368, 50:1200] 280 | # disp = disp[:, 0:368, 50:1200] 281 | # print('gt shape:', gt.shape) 282 | 283 | #mask = gt[:, 0:368, 50:1232]> 0 284 | 285 | mask = gt > 0 286 | mask = mask * (gt < maxdisp) 287 | 288 | errmap = torch.abs(disp - gt) 289 | err3 = ((errmap[mask] > 3.) 
& (errmap[mask] / gt[mask] > 0.05)).sum() 290 | return err3.float() / mask.sum().float() 291 | 292 | 293 | def adjust_learning_rate(optimizer, epoch): 294 | if epoch <= 1000: 295 | lr = args.lr 296 | elif epoch <= 1500: 297 | lr = args.lr * 0.1 298 | else: 299 | lr = args.lr * 0.01 300 | for param_group in optimizer.param_groups: 301 | param_group['lr'] = lr 302 | 303 | class AverageMeter(object): 304 | """Computes and stores the average and current value""" 305 | 306 | def __init__(self): 307 | self.reset() 308 | 309 | def reset(self): 310 | self.val = 0 311 | self.avg = 0 312 | self.sum = 0 313 | self.count = 0 314 | 315 | def update(self, val, n=1): 316 | self.val = val 317 | self.sum += val * n 318 | self.count += n 319 | self.avg = self.sum / self.count 320 | 321 | 322 | def post_process_disparity(disp): 323 | _, h, w = disp[0].shape 324 | #print('disp[0].shape:', disp[0].shape) # torch.Size([1, 368, 1232]) 325 | 326 | l_disp = disp[0].cpu().numpy() 327 | #r_disp = np.fliplr(disp[1].cpu()) 328 | r_disp = disp[1].cpu().numpy() 329 | 330 | #m_disp = 0.5 * (l_disp + r_disp) 331 | 332 | l, _ = np.meshgrid(np.linspace(0, 1, w), np.linspace(0, 1, h)) 333 | l_mask = 1.0 - np.clip(20 * (l - 0.05), 0, 1) 334 | #r_mask =np.fliplr(l_mask) 335 | # return r_mask * l_disp + l_mask * r_disp + (1.0 - l_mask - r_mask) * m_disp 336 | return l_mask * r_disp + (1.0 - l_mask ) * l_disp 337 | # originally l_disp was just dropped here directly 338 | 339 | 340 | 341 | 342 | 343 | if __name__ == '__main__': 344 | main() 345 | 346 | 347 | 348 | --------------------------------------------------------------------------------
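# Note: post_process_disparity() above relies on numpy as np, but One_cycle.py
# never imports numpy, so calling it only works if `np` happens to be
# re-exported by the `from models.LWADNet import *` star import.
#
# Both Online_adaptation.py and One_cycle.py write one 3-pixel-error (D1) value
# per line to a plain text file next to the checkpoints. A minimal sketch for
# plotting such a curve afterwards; the file and output names below are
# examples, not part of the repository:

import matplotlib.pyplot as plt


def plot_adaptation_curve(txt_path, out_png):
    # every line in the file holds one D1 value logged during adaptation
    with open(txt_path) as f:
        d1 = [float(line) for line in f if line.strip()]
    plt.figure()
    plt.plot(range(len(d1)), d1)
    plt.xlabel('adaptation step')
    plt.ylabel('D1 (3-pixel error)')
    plt.savefig(out_png, bbox_inches='tight')


# e.g. plot_adaptation_curve('./results/finetune_One_cycle/GT_supervise/self_supervise.txt',
#                            './results/finetune_One_cycle/GT_supervise/d1_curve.png')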