├── .idea
│   ├── .gitignore
│   ├── inspectionProfiles
│   │   ├── profiles_settings.xml
│   │   └── Project_Default.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── sshConfigs.xml
│   ├── deployment.xml
│   ├── webServers.xml
│   ├── LWANet.iml
│   └── workspace.xml
├── utils
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── logger.cpython-36.pyc
│   │   └── __init__.cpython-36.pyc
│   ├── logger.py
│   ├── readpfm.py
│   ├── merge_img2video.py
│   ├── preprocess.py
│   └── flops_hook.py
├── dataloader
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── readpfm.cpython-36.pyc
│   │   ├── __init__.cpython-36.pyc
│   │   ├── KITTILoader.cpython-36.pyc
│   │   ├── listflowfile.cpython-36.pyc
│   │   ├── preprocess.cpython-36.pyc
│   │   ├── KITTILoader_mask.cpython-36.pyc
│   │   ├── KITTILoader_video.cpython-36.pyc
│   │   ├── KITTI_0028_sync.cpython-36.pyc
│   │   ├── KITTIloader2012.cpython-36.pyc
│   │   ├── KITTIloader2015.cpython-36.pyc
│   │   ├── SecenFlowLoader.cpython-36.pyc
│   │   ├── preprocess_change.cpython-36.pyc
│   │   ├── KITTILoader_change.cpython-36.pyc
│   │   ├── KITTILoader_0028_0071.cpython-36.pyc
│   │   ├── KITTILoader_One_cycle.cpython-36.pyc
│   │   ├── KITTILoader_supervised.cpython-36.pyc
│   │   ├── KITTIloader2015_mask.cpython-36.pyc
│   │   ├── KITTIloader2015_test.cpython-36.pyc
│   │   ├── KITTIloader2015_video.cpython-36.pyc
│   │   ├── KITTIloader_0028_sync.cpython-36.pyc
│   │   ├── KITTILoader_submit_to_2015.cpython-36.pyc
│   │   ├── KITTIloader2015_One_cycle.cpython-36.pyc
│   │   ├── KITTIloader2015_supervised.cpython-36.pyc
│   │   ├── KITTIloader_list_0028_0071.cpython-36.pyc
│   │   ├── preprocess_submit_to_2015.cpython-36.pyc
│   │   └── KITTIloader2015_submit_to_2015.cpython-36.pyc
│   ├── KITTI_submission_loader.py
│   ├── readpfm.py
│   ├── KITTILoader_One_cycle.py
│   ├── KITTILoader_0028_0071.py
│   ├── KITTIloader2015_One_cycle.py
│   ├── SecenFlowLoader.py
│   ├── KITTILoader.py
│   ├── KITTIdatalist.py
│   ├── listflowfile.py
│   └── preprocess.py
├── models
│   ├── __pycache__
│   │   ├── comm.cpython-36.pyc
│   │   ├── cspn.cpython-36.pyc
│   │   ├── LWADNet.cpython-36.pyc
│   │   ├── cspn_5_12.cpython-36.pyc
│   │   ├── LWADNet_5_12.cpython-36.pyc
│   │   ├── LWADNet_flops.cpython-36.pyc
│   │   ├── LWADNet_submodules.cpython-36.pyc
│   │   ├── LWADNet_submodules_BN.cpython-36.pyc
│   │   ├── batch_normalization.cpython-36.pyc
│   │   ├── LWADNet_submodules_test.cpython-36.pyc
│   │   └── LWADNet_submodules_IN_FRN.cpython-36.pyc
│   ├── feature_extraction.py
│   ├── loss.py
│   ├── Aggregation_submodules.py
│   ├── LWANet.py
│   ├── cost.py
│   └── cspn.py
├── README.md
├── env.yaml
├── submission.py
├── main.py
├── Online_adaptation.py
├── finetune.py
└── One_cycle.py
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/dataloader/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/__pycache__/comm.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANWANSHUI/LWANet/HEAD/models/__pycache__/comm.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/cspn.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANWANSHUI/LWANet/HEAD/models/__pycache__/cspn.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/logger.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANWANSHUI/LWANet/HEAD/utils/__pycache__/logger.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/LWADNet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANWANSHUI/LWANet/HEAD/models/__pycache__/LWADNet.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANWANSHUI/LWANet/HEAD/utils/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/dataloader/__pycache__/readpfm.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANWANSHUI/LWANet/HEAD/dataloader/__pycache__/readpfm.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/cspn_5_12.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANWANSHUI/LWANet/HEAD/models/__pycache__/cspn_5_12.cpython-36.pyc
--------------------------------------------------------------------------------
/dataloader/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANWANSHUI/LWANet/HEAD/dataloader/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/LWADNet_5_12.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANWANSHUI/LWANet/HEAD/models/__pycache__/LWADNet_5_12.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/LWADNet_flops.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANWANSHUI/LWANet/HEAD/models/__pycache__/LWADNet_flops.cpython-36.pyc
--------------------------------------------------------------------------------
/dataloader/__pycache__/KITTILoader.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANWANSHUI/LWANet/HEAD/dataloader/__pycache__/KITTILoader.cpython-36.pyc
--------------------------------------------------------------------------------
/dataloader/__pycache__/listflowfile.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANWANSHUI/LWANet/HEAD/dataloader/__pycache__/listflowfile.cpython-36.pyc
--------------------------------------------------------------------------------
/dataloader/__pycache__/preprocess.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANWANSHUI/LWANet/HEAD/dataloader/__pycache__/preprocess.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/LWADNet_submodules.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANWANSHUI/LWANet/HEAD/models/__pycache__/LWADNet_submodules.cpython-36.pyc
--------------------------------------------------------------------------------
/dataloader/__pycache__/KITTILoader_mask.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANWANSHUI/LWANet/HEAD/dataloader/__pycache__/KITTILoader_mask.cpython-36.pyc
--------------------------------------------------------------------------------
/dataloader/__pycache__/KITTILoader_video.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANWANSHUI/LWANet/HEAD/dataloader/__pycache__/KITTILoader_video.cpython-36.pyc
--------------------------------------------------------------------------------
/dataloader/__pycache__/KITTI_0028_sync.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANWANSHUI/LWANet/HEAD/dataloader/__pycache__/KITTI_0028_sync.cpython-36.pyc
--------------------------------------------------------------------------------
/dataloader/__pycache__/KITTIloader2012.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANWANSHUI/LWANet/HEAD/dataloader/__pycache__/KITTIloader2012.cpython-36.pyc
--------------------------------------------------------------------------------
/dataloader/__pycache__/KITTIloader2015.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANWANSHUI/LWANet/HEAD/dataloader/__pycache__/KITTIloader2015.cpython-36.pyc
--------------------------------------------------------------------------------
/dataloader/__pycache__/SecenFlowLoader.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANWANSHUI/LWANet/HEAD/dataloader/__pycache__/SecenFlowLoader.cpython-36.pyc
--------------------------------------------------------------------------------
/dataloader/__pycache__/preprocess_change.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANWANSHUI/LWANet/HEAD/dataloader/__pycache__/preprocess_change.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/LWADNet_submodules_BN.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANWANSHUI/LWANet/HEAD/models/__pycache__/LWADNet_submodules_BN.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/batch_normalization.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANWANSHUI/LWANet/HEAD/models/__pycache__/batch_normalization.cpython-36.pyc
--------------------------------------------------------------------------------
/dataloader/__pycache__/KITTILoader_change.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANWANSHUI/LWANet/HEAD/dataloader/__pycache__/KITTILoader_change.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/LWADNet_submodules_test.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANWANSHUI/LWANet/HEAD/models/__pycache__/LWADNet_submodules_test.cpython-36.pyc
--------------------------------------------------------------------------------
/dataloader/__pycache__/KITTILoader_0028_0071.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANWANSHUI/LWANet/HEAD/dataloader/__pycache__/KITTILoader_0028_0071.cpython-36.pyc
--------------------------------------------------------------------------------
/dataloader/__pycache__/KITTILoader_One_cycle.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANWANSHUI/LWANet/HEAD/dataloader/__pycache__/KITTILoader_One_cycle.cpython-36.pyc
--------------------------------------------------------------------------------
/dataloader/__pycache__/KITTILoader_supervised.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANWANSHUI/LWANet/HEAD/dataloader/__pycache__/KITTILoader_supervised.cpython-36.pyc
--------------------------------------------------------------------------------
/dataloader/__pycache__/KITTIloader2015_mask.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANWANSHUI/LWANet/HEAD/dataloader/__pycache__/KITTIloader2015_mask.cpython-36.pyc
--------------------------------------------------------------------------------
/dataloader/__pycache__/KITTIloader2015_test.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANWANSHUI/LWANet/HEAD/dataloader/__pycache__/KITTIloader2015_test.cpython-36.pyc
--------------------------------------------------------------------------------
/dataloader/__pycache__/KITTIloader2015_video.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANWANSHUI/LWANet/HEAD/dataloader/__pycache__/KITTIloader2015_video.cpython-36.pyc
--------------------------------------------------------------------------------
/dataloader/__pycache__/KITTIloader_0028_sync.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANWANSHUI/LWANet/HEAD/dataloader/__pycache__/KITTIloader_0028_sync.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/LWADNet_submodules_IN_FRN.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANWANSHUI/LWANet/HEAD/models/__pycache__/LWADNet_submodules_IN_FRN.cpython-36.pyc
--------------------------------------------------------------------------------
/dataloader/__pycache__/KITTILoader_submit_to_2015.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANWANSHUI/LWANet/HEAD/dataloader/__pycache__/KITTILoader_submit_to_2015.cpython-36.pyc
--------------------------------------------------------------------------------
/dataloader/__pycache__/KITTIloader2015_One_cycle.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANWANSHUI/LWANet/HEAD/dataloader/__pycache__/KITTIloader2015_One_cycle.cpython-36.pyc
--------------------------------------------------------------------------------
/dataloader/__pycache__/KITTIloader2015_supervised.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANWANSHUI/LWANet/HEAD/dataloader/__pycache__/KITTIloader2015_supervised.cpython-36.pyc
--------------------------------------------------------------------------------
/dataloader/__pycache__/KITTIloader_list_0028_0071.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANWANSHUI/LWANet/HEAD/dataloader/__pycache__/KITTIloader_list_0028_0071.cpython-36.pyc
--------------------------------------------------------------------------------
/dataloader/__pycache__/preprocess_submit_to_2015.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANWANSHUI/LWANet/HEAD/dataloader/__pycache__/preprocess_submit_to_2015.cpython-36.pyc
--------------------------------------------------------------------------------
/dataloader/__pycache__/KITTIloader2015_submit_to_2015.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANWANSHUI/LWANet/HEAD/dataloader/__pycache__/KITTIloader2015_submit_to_2015.cpython-36.pyc
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
(XML content not preserved in this dump; the markup was stripped during extraction.)
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
(XML content not preserved in this dump; the markup was stripped during extraction.)
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
(XML content not preserved in this dump; the markup was stripped during extraction.)
--------------------------------------------------------------------------------
/.idea/sshConfigs.xml:
--------------------------------------------------------------------------------
(XML content not preserved in this dump; the markup was stripped during extraction.)
--------------------------------------------------------------------------------
/.idea/deployment.xml:
--------------------------------------------------------------------------------
(XML content not preserved in this dump; the markup was stripped during extraction.)
--------------------------------------------------------------------------------
/.idea/webServers.xml:
--------------------------------------------------------------------------------
(XML content not preserved in this dump; the markup was stripped during extraction.)
--------------------------------------------------------------------------------
/.idea/LWANet.iml:
--------------------------------------------------------------------------------
(XML content not preserved in this dump; the markup was stripped during extraction.)
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
(XML content not preserved in this dump; the markup was stripped during extraction.)
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 |
4 |
5 | def setup_logger(filepath):
6 | file_formatter = logging.Formatter(
7 | "[%(asctime)s %(filename)s:%(lineno)s] %(levelname)-8s %(message)s",
8 | datefmt='%Y-%m-%d %H:%M:%S',
9 | )
10 | logger = logging.getLogger('example')
11 | handler = logging.StreamHandler()
12 | handler.setFormatter(file_formatter)
13 | logger.addHandler(handler)
14 |
15 | file_handle_name = "file"
16 | if file_handle_name in [h.name for h in logger.handlers]:
17 | return logger
18 | if os.path.dirname(filepath) != '':
19 | if not os.path.isdir(os.path.dirname(filepath)):
20 | os.makedirs(os.path.dirname(filepath))
21 | file_handle = logging.FileHandler(filename=filepath, mode="a")
22 | file_handle.set_name(file_handle_name)
23 | file_handle.setFormatter(file_formatter)
24 | logger.addHandler(file_handle)
25 | logger.setLevel(logging.DEBUG)
26 | return logger
--------------------------------------------------------------------------------
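A minimal usage sketch for `setup_logger` above (the log path is an example, not taken from the repo): the function attaches a console handler plus a named file handler, creating the log directory on first use.

```python
# Hypothetical usage of utils/logger.py; './results/train.log' is an example path.
from utils.logger import setup_logger

log = setup_logger('./results/train.log')   # creates ./results/ if it does not exist
log.info('training started')                # written to both the console and the file
```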
/dataloader/KITTI_submission_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 |
4 | IMG_EXTENSIONS = [
5 | '.jpg', '.JPG', '.jpeg', '.JPEG',
6 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
7 | ]
8 |
9 |
10 | def is_image_file(filename):
11 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
12 |
13 |
14 | def dataloader2015(filepath):
15 |
16 | left_fold = 'image_2/'
17 | right_fold = 'image_3/'
18 |
19 |
20 | image = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1]
21 |
22 |
23 | left_test = [filepath+left_fold+img for img in image]
24 | right_test = [filepath+right_fold+img for img in image]
25 |
26 | return left_test, right_test
27 |
28 |
29 | def dataloader2012(filepath):
30 |
31 | left_fold = 'colored_0/'
32 | right_fold = 'colored_1/'
33 |
34 |
35 | image = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1]
36 |
37 |
38 | left_test = [filepath+left_fold+img for img in image]
39 | right_test = [filepath+right_fold+img for img in image]
40 |
41 | return left_test, right_test
42 |
--------------------------------------------------------------------------------
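A hedged example of calling the two list builders above; the dataset root is an example, and `filepath` must end with a trailing slash because the folder names are concatenated onto it directly.

```python
# Hypothetical call; '/data/KITTI2015/testing/' stands in for the real dataset root.
from dataloader import KITTI_submission_loader as ksl

left_test, right_test = ksl.dataloader2015('/data/KITTI2015/testing/')
print(len(left_test), left_test[0])   # paired image_2/image_3 frames ending in _10
```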
/utils/readpfm.py:
--------------------------------------------------------------------------------
1 | import re
2 | import numpy as np
3 | import sys
4 |
5 |
6 | def readPFM(file):
7 | file = open(file, 'rb')
8 |
9 | color = None
10 | width = None
11 | height = None
12 | scale = None
13 | endian = None
14 |
15 | header = file.readline().rstrip()
16 | if header == b'PF':
17 | color = True
18 | elif header == b'Pf':
19 | color = False
20 | else:
21 | raise Exception('Not a PFM file.')
22 |
23 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8'))
24 | if dim_match:
25 | width, height = map(int, dim_match.groups())
26 | else:
27 | raise Exception('Malformed PFM header.')
28 |
29 | scale = float(file.readline().rstrip())
30 | if scale < 0: # little-endian
31 | endian = '<'
32 | scale = -scale
33 | else:
34 | endian = '>' # big-endian
35 |
36 | data = np.fromfile(file, endian + 'f')
37 | shape = (height, width, 3) if color else (height, width)
38 |
39 | data = np.reshape(data, shape)
40 | data = np.flipud(data)
41 | return data, scale
42 |
43 |
--------------------------------------------------------------------------------
/dataloader/readpfm.py:
--------------------------------------------------------------------------------
1 | import re
2 | import numpy as np
3 | import sys
4 |
5 |
6 | def readPFM(file):
7 | file = open(file, 'rb')
8 |
9 | color = None
10 | width = None
11 | height = None
12 | scale = None
13 | endian = None
14 |
15 | header = file.readline().rstrip()
16 | if header == b'PF':
17 | color = True
18 | elif header == b'Pf':
19 | color = False
20 | else:
21 | raise Exception('Not a PFM file.')
22 |
23 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8'))
24 | if dim_match:
25 | width, height = map(int, dim_match.groups())
26 | else:
27 | raise Exception('Malformed PFM header.')
28 |
29 | scale = float(file.readline().rstrip())
30 | if scale < 0: # little-endian
31 | endian = '<'
32 | scale = -scale
33 | else:
34 | endian = '>' # big-endian
35 |
36 | data = np.fromfile(file, endian + 'f')
37 | shape = (height, width, 3) if color else (height, width)
38 |
39 | data = np.reshape(data, shape)
40 | data = np.flipud(data)
41 | file.close()
42 | return data, scale
43 |
--------------------------------------------------------------------------------
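A short sketch of reading one SceneFlow disparity map with the Python-3 `readPFM` above (the `.pfm` path is an example):

```python
# Hypothetical usage; the path below is an example, not a file shipped with the repo.
from dataloader.readpfm import readPFM

disp, scale = readPFM('/data/SceneFlow/disparity/0006.pfm')
print(disp.shape, disp.dtype, scale)   # (H, W) float32 disparity map and its scale factor
```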
/README.md:
--------------------------------------------------------------------------------
1 | # LWANet
2 | This repository contains the code for our paper: [Light-weight Network for Real-time Adaptive Stereo Depth Estimation](https://www.sciencedirect.com/science/article/pii/S0925231221002599)
3 |
4 | # Abstract
5 | Self-supervised learning methods have been proved effective in the task of real-time stereo
6 | depth estimation with the requirement of lower memory space and less computational cost. In this
7 | paper, a light-weight adaptive network (LWANet) is proposed by combining the self-supervised
8 | learning method to perform online adaptive stereo depth estimation for low computation cost and
9 | low GPU memory space. Instead of a regular 3D convolution, the pseudo 3D convolution is
10 | employed in the proposed light-weight network to aggregate the cost volume for achieving a better
11 | balance between the accuracy and the computational cost. Moreover, based on U-Net architecture,
12 | the downsample feature extractor is combined with a refined convolutional spatial propagation
13 | network (CSPN) to further refine the estimation accuracy with little memory space and
14 | computational cost. Extensive experiments demonstrate that the proposed LWANet effectively
15 | alleviates the domain shift problem by online updating the neural network, which is suitable for
16 | embedded devices such as NVIDIA Jetson TX2.
17 |
18 | # Usage
19 |
20 | To be updated
21 |
22 | # Citation
23 |
24 | If you find this work useful, please cite:
25 |
26 | ```
27 | Gan, W., Wong, P. K., Yu, G., Zhao, R., & Vong, C. M. (2021). Light-weight Network for Real-time Adaptive Stereo Depth Estimation. Neurocomputing.
28 | ```
29 |
30 | # Acknowledgement
31 |
32 | Many thanks to authors of [AnyNet](https://github.com/mileyan/AnyNet) and [CSPN](https://github.com/XinJCheng/CSPN) for open-sourcing the code.
33 |
--------------------------------------------------------------------------------
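For convenience, the plain-text reference in README.md can also be written as a BibTeX entry; the entry key below is made up, and the author names are kept as the initials given in the citation.

```bibtex
@article{gan2021lwanet,
  title   = {Light-weight Network for Real-time Adaptive Stereo Depth Estimation},
  author  = {Gan, W. and Wong, P. K. and Yu, G. and Zhao, R. and Vong, C. M.},
  journal = {Neurocomputing},
  year    = {2021}
}
```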
/dataloader/KITTILoader_One_cycle.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | from PIL import Image, ImageOps
3 | import numpy as np
4 | from . import preprocess
5 |
6 | IMG_EXTENSIONS = [
7 | '.jpg', '.JPG', '.jpeg', '.JPEG',
8 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
9 | ]
10 |
11 | def is_image_file(filename):
12 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
13 |
14 | def default_loader(path):
15 | return Image.open(path).convert('RGB')
16 |
17 | def disparity_loader(path):
18 | return Image.open(path)
19 |
20 |
21 | class myImageFloder(data.Dataset):
22 | def __init__(self, left, right, left_disparity, training, loader=default_loader, dploader= disparity_loader):
23 |
24 | self.left = left
25 | self.right = right
26 | self.disp_L = left_disparity
27 | self.loader = loader
28 | self.dploader = dploader
29 | self.training = training
30 |
31 | def __getitem__(self, index):
32 | left = self.left[index]
33 | right = self.right[index]
34 | disp_L= self.disp_L[index]
35 |
36 | left_img = self.loader(left)
37 | right_img = self.loader(right)
38 | dataL = self.dploader(disp_L)
39 |
40 |
41 |
42 | w, h = left_img.size
43 |
44 | left_img = left_img.crop((w - 1216, h - 320, w, h))
45 | right_img = right_img.crop((w - 1216, h - 320, w, h))
46 |
47 |
48 |
49 | dataL = dataL.crop((w-1216, h-320, w, h))
50 |
51 | dataL = np.ascontiguousarray(dataL,dtype=np.float32)/256
52 |
53 | processed = preprocess.get_transform(augment=False)
54 | left_img = processed(left_img)
55 | right_img = processed(right_img)
56 |
57 |
58 |
59 | return left_img, right_img, dataL
60 |
61 |
62 |
63 |
64 |
65 |
66 | def __len__(self):
67 | return len(self.left)
68 |
69 |
70 |
71 |
--------------------------------------------------------------------------------
/dataloader/KITTILoader_0028_0071.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | from PIL import Image, ImageOps
3 | import numpy as np
4 | from . import preprocess
5 |
6 | IMG_EXTENSIONS = [
7 | '.jpg', '.JPG', '.jpeg', '.JPEG',
8 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
9 | ]
10 |
11 | def is_image_file(filename):
12 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
13 |
14 | def default_loader(path):
15 | return Image.open(path).convert('RGB')
16 |
17 | def disparity_loader(path):
18 | return Image.open(path)
19 |
20 |
21 | class myImageFloder(data.Dataset):
22 | def __init__(self, left, right, left_disparity, training, loader=default_loader, dploader= disparity_loader):
23 |
24 | self.left = left
25 | self.right = right
26 | self.disp_L = left_disparity
27 | self.loader = loader
28 | self.dploader = dploader
29 | self.training = training
30 |
31 | def __getitem__(self, index):
32 | left = self.left[index]
33 | right = self.right[index]
34 | disp_L= self.disp_L[index]
35 |
36 | left_img = self.loader(left)
37 | right_img = self.loader(right)
38 | dataL = self.dploader(disp_L)
39 |
40 |
41 | # full image
42 | w, h = left_img.size
43 | left_img = left_img.crop((w - 1216, h - 320, w, h))
44 | right_img = right_img.crop((w - 1216, h - 320, w, h))
45 | dataL = dataL.crop((w - 1216, h - 320, w, h))
46 | dataL = np.ascontiguousarray(dataL, dtype=np.float32) / 256
47 |
48 | # 0028
49 | #dataL = 0.54 * 707 / dataL
50 |
51 | # 0071
52 | dataL = 0.54 * 718 / dataL
53 |
54 | processed = preprocess.get_transform(augment=False)
55 | left_img = processed(left_img)
56 | right_img = processed(right_img)
57 |
58 |
59 | return left_img, right_img, dataL
60 |
61 |
62 | def __len__(self):
63 | return len(self.left)
64 |
65 |
66 |
67 |
--------------------------------------------------------------------------------
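The constants in the loader above follow the standard stereo relation disparity = focal_length × baseline / depth. With the KITTI stereo baseline of roughly 0.54 m, the commented-out 0.54 × 707 matches a focal length of about 707 px (the 0028 sequence) and the active 0.54 × 718 matches about 718 px (the 0071 sequence); dividing by the values loaded from disk therefore converts between depth and disparity, e.g. a ground-truth value of 20 m maps to 0.54 × 718 / 20 ≈ 19.4 px. This reading is inferred from the 0028/0071 comments and the usual KITTI calibration, not stated explicitly in the file.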
/dataloader/KITTIloader2015_One_cycle.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 | from PIL import Image
4 | import os
5 | import os.path
6 | import numpy as np
7 |
8 | IMG_EXTENSIONS = [
9 | '.jpg', '.JPG', '.jpeg', '.JPEG',
10 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
11 | ]
12 |
13 |
14 | def is_image_file(filename):
15 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
16 |
17 | def dataloader(filepath, log):
18 |
19 | left_fold = 'image_2/'
20 | right_fold = 'image_3/'
21 | disp_L = 'disp_occ_0/'
22 | #disp_R = 'disp_occ_1/'
23 | #
24 | # left_fold = 'image_02/data/'
25 | # right_fold = 'image_03/data/'
26 | # disp_L = 'data_depth_annotated/2011_09_30_drive_0028_sync/proj_depth/groundtruth/image_02/'
27 | #disp_R = 'disp_occ_1/'
28 |
29 | image = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1]
30 | #image = [img for img in os.listdir(filepath + left_fold) if img.find('000000_10')]
31 | #print('image 0:', len(image))
32 |
33 | all_index = np.arange(200)
34 | #np.random.seed(2)
35 | #np.random.shuffle(all_index)
36 | #print('all_index:', all_index)
37 | vallist = all_index[:40]
38 |
39 | log.info(vallist)
40 | val = ['{:06d}_10.png'.format(x) for x in vallist]
41 | #train = [x for x in image if x not in val]
42 | train = [x for x in image if x == '000128_10.png']
43 | print('train :', train[0])
44 |
45 |
46 |
47 |
48 |
49 | left_train = [filepath+left_fold+img for img in train]
50 | right_train = [filepath+right_fold+img for img in train]
51 | disp_train_L = [filepath+disp_L+img for img in train]
52 | #disp_train_R = [filepath+disp_R+img for img in train]
53 |
54 | left_val = [filepath+left_fold+img for img in val]
55 | right_val = [filepath+right_fold+img for img in val]
56 | disp_val_L = [filepath+disp_L+img for img in val]
57 | #disp_val_R = [filepath+disp_R+img for img in val]
58 |
59 |
60 | return left_train, right_train, disp_train_L, left_val, right_val, disp_val_L
61 | #return left_train, right_train, disp_train_L, disp_train_R, left_val, right_val, disp_val_L, disp_val_R
--------------------------------------------------------------------------------
/dataloader/SecenFlowLoader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.utils.data as data
4 | import torch
5 | import torchvision.transforms as transforms
6 | import random
7 | from PIL import Image, ImageOps
8 | from . import preprocess
9 | from . import listflowfile as lt
10 | from . import readpfm as rp
11 | import numpy as np
12 |
13 | IMG_EXTENSIONS = [
14 | '.jpg', '.JPG', '.jpeg', '.JPEG',
15 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
16 | ]
17 |
18 |
19 | def is_image_file(filename):
20 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
21 |
22 |
23 | def default_loader(path):
24 | return Image.open(path).convert('RGB')
25 |
26 |
27 | def disparity_loader(path):
28 | return rp.readPFM(path)
29 |
30 |
31 | class myImageFloder(data.Dataset):
32 | def __init__(self, left, right, left_disparity, training, loader=default_loader, dploader=disparity_loader):
33 |
34 | self.left = left
35 | self.right = right
36 | self.disp_L = left_disparity
37 | self.loader = loader
38 | self.dploader = dploader
39 | self.training = training
40 |
41 | def __getitem__(self, index):
42 | left = self.left[index]
43 | right = self.right[index]
44 | disp_L = self.disp_L[index]
45 |
46 | left_img = self.loader(left)
47 | right_img = self.loader(right)
48 | dataL, scaleL = self.dploader(disp_L)
49 | dataL = np.ascontiguousarray(dataL, dtype=np.float32)
50 |
51 | if self.training:
52 | w, h = left_img.size
53 | #th, tw = 256, 512
54 | th, tw = 512, 960
55 |
56 | x1 = random.randint(0, w - tw)
57 | y1 = random.randint(0, h - th)
58 |
59 | left_img = left_img.crop((x1, y1, x1 + tw, y1 + th))
60 | right_img = right_img.crop((x1, y1, x1 + tw, y1 + th))
61 |
62 | dataL = dataL[y1:y1 + th, x1:x1 + tw]
63 |
64 | processed = preprocess.get_transform(augment=False)
65 | left_img = processed(left_img)
66 | right_img = processed(right_img)
67 |
68 | return left_img, right_img, dataL
69 | else:
70 | w, h = left_img.size
71 | left_img = left_img.crop((w - 960, h - 544, w, h))
72 | right_img = right_img.crop((w - 960, h - 544, w, h))
73 | processed = preprocess.get_transform(augment=False)
74 | left_img = processed(left_img)
75 | right_img = processed(right_img)
76 |
77 | return left_img, right_img, dataL
78 |
79 | def __len__(self):
80 | return len(self.left)
81 |
--------------------------------------------------------------------------------
/dataloader/KITTILoader.py:
--------------------------------------------------------------------------------
1 |
2 | import torch.utils.data as data
3 | import random
4 | from PIL import Image, ImageOps
5 | import numpy as np
6 | from . import preprocess
7 |
8 | IMG_EXTENSIONS = [
9 | '.jpg', '.JPG', '.jpeg', '.JPEG',
10 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
11 | ]
12 |
13 | def is_image_file(filename):
14 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
15 |
16 | def default_loader(path):
17 | return Image.open(path).convert('RGB')
18 |
19 | def disparity_loader(path):
20 | return Image.open(path)
21 |
22 |
23 | class myImageFloder(data.Dataset):
24 | def __init__(self, left, right, left_disparity, training, loader=default_loader, dploader= disparity_loader):
25 |
26 | self.left = left
27 | self.right = right
28 | self.disp_L = left_disparity
29 | self.loader = loader
30 | self.dploader = dploader
31 | self.training = training
32 |
33 | def __getitem__(self, index):
34 | left = self.left[index]
35 | right = self.right[index]
36 | disp_L= self.disp_L[index]
37 |
38 | left_img = self.loader(left)
39 | right_img = self.loader(right)
40 | dataL = self.dploader(disp_L)
41 |
42 |
43 | if self.training:
44 | w, h = left_img.size
45 | #print(' w, h:', w, h)
46 | #th, tw = 256, 512
47 | th, tw = 288, 624
48 |
49 |
50 | x1 = random.randint(0, w - tw)
51 |
52 | y1 = random.randint(0, h - th)
53 |
54 | left_img = left_img.crop((x1, y1, x1 + tw, y1 + th))
55 | right_img = right_img.crop((x1, y1, x1 + tw, y1 + th))
56 |
57 |
58 |
59 | dataL = np.ascontiguousarray(dataL,dtype=np.float32)/256
60 | dataL = dataL[y1:y1 + th, x1:x1 + tw]
61 |
62 | processed = preprocess.get_transform(augment=False)
63 | left_img = processed(left_img)
64 | right_img = processed(right_img)
65 |
66 | return left_img, right_img, dataL
67 |
68 |
69 |
70 |
71 | else:
72 | w, h = left_img.size
73 |
74 | left_img = left_img.crop((w-1232, h-368, w, h))
75 | right_img = right_img.crop((w-1232, h-368, w, h))
76 |
77 | dataL = dataL.crop((w-1232, h-368, w, h))
78 |
79 | dataL = np.ascontiguousarray(dataL,dtype=np.float32)/256
80 |
81 | processed = preprocess.get_transform(augment=False)
82 | left_img = processed(left_img)
83 | right_img = processed(right_img)
84 |
85 |
86 |
87 | return left_img, right_img, dataL
88 |
89 |
90 | def __len__(self):
91 | return len(self.left)
92 |
93 |
--------------------------------------------------------------------------------
/.idea/workspace.xml:
--------------------------------------------------------------------------------
(XML content not preserved in this dump; only two timestamp values, "1610205822897", survive from the original markup.)
--------------------------------------------------------------------------------
/dataloader/KITTIdatalist.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 | from PIL import Image
4 | import os
5 | import os.path
6 | import numpy as np
7 | import random
8 |
9 | IMG_EXTENSIONS = [
10 | '.jpg', '.JPG', '.jpeg', '.JPEG',
11 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
12 | ]
13 |
14 |
15 | def is_image_file(filename):
16 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
17 |
18 | def dataloader2012(filepath, log, split=False):
19 |
20 | left_fold = 'colored_0/'
21 | right_fold = 'colored_1/'
22 | disp_noc = 'disp_occ/'
23 |
24 | image = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1]
25 | random.shuffle(image)
26 |
27 |
28 | if not split:
29 |
30 | np.random.seed(2)
31 | random.shuffle(image)
32 | train = image[:]
33 | val = image[160:]
34 |
35 | else:
36 |
37 | train = image[:160]
38 | val = image[160:]
39 |
40 |
41 |
42 | log.info(val)
43 |
44 | left_train = [filepath+left_fold+img for img in train]
45 | right_train = [filepath+right_fold+img for img in train]
46 | disp_train = [filepath+disp_noc+img for img in train]
47 |
48 |
49 | left_val = [filepath+left_fold+img for img in val]
50 | right_val = [filepath+right_fold+img for img in val]
51 | disp_val = [filepath+disp_noc+img for img in val]
52 |
53 | return left_train, right_train, disp_train, left_val, right_val, disp_val
54 |
55 |
56 |
57 | def dataloader2015(filepath, log, split = False):
58 |
59 | left_fold = 'image_2/'
60 | right_fold = 'image_3/'
61 | disp_L = 'disp_occ_0/'
62 |
63 | image = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1]
64 |
65 | all_index = np.arange(200)
66 | np.random.seed(2)
67 | np.random.shuffle(all_index)
68 | #print('all_index:', all_index)
69 | vallist = all_index[:40]
70 |
71 | log.info(vallist)
72 | val = ['{:06d}_10.png'.format(x) for x in vallist]
73 |
74 | if split:
75 | train = [x for x in image if x not in val]
76 | # train = [x for x in image if x not in val]
77 |
78 | else:
79 | train = [x for x in image]
80 |
81 |
82 |
83 | left_train = [filepath+left_fold+img for img in train]
84 | right_train = [filepath+right_fold+img for img in train]
85 | disp_train_L = [filepath+disp_L+img for img in train]
86 | #disp_train_R = [filepath+disp_R+img for img in train]
87 |
88 | left_val = [filepath+left_fold+img for img in val]
89 | right_val = [filepath+right_fold+img for img in val]
90 | disp_val_L = [filepath+disp_L+img for img in val]
91 | #disp_val_R = [filepath+disp_R+img for img in val]
92 |
93 |
94 | return left_train, right_train, disp_train_L, left_val, right_val, disp_val_L
95 |
96 |
97 |
98 |
99 | def dataloader_adaptation(filepath, datatype):
100 |
101 | # 0028
102 | left_fold = 'raw_image/image_02/data/'
103 | right_fold = 'raw_image/image_03/data/' # w, h: 1226 370
104 | disp_L = 'disparity/image_02/'
105 |
106 | path_list = os.listdir(filepath + left_fold)
107 | path_list.sort(key=lambda x: int(x.split('.')[0]))
108 | image = [img for img in path_list]
109 |
110 |
111 | #0028
112 | if datatype == "0028":
113 | image = image[5:2005]
114 |
115 |
116 | elif datatype == "0071":
117 | # 0071
118 | image = image[5:-6]
119 |
120 |
121 | train = [x for x in image]
122 |
123 | left_train = [filepath+left_fold+img for img in train]
124 | right_train = [filepath+right_fold+img for img in train]
125 | disp_train_L = [filepath+disp_L+img for img in train]
126 |
127 | return left_train, right_train,
--------------------------------------------------------------------------------
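Putting the pieces together, a hedged sketch of how the KITTI 2015 lists above can feed `KITTILoader.myImageFloder` through a standard PyTorch `DataLoader`; the dataset root, batch size, and log path are examples rather than values taken from the repo.

```python
# Hypothetical wiring of KITTIdatalist + KITTILoader; paths and hyper-parameters are examples.
import torch.utils.data
from dataloader import KITTIdatalist as ls
from dataloader import KITTILoader as DA
from utils.logger import setup_logger

log = setup_logger('./results/datalist.log')
l_tr, r_tr, d_tr, l_val, r_val, d_val = ls.dataloader2015('/data/KITTI2015/training/', log, split=True)

train_loader = torch.utils.data.DataLoader(
    DA.myImageFloder(l_tr, r_tr, d_tr, training=True),
    batch_size=4, shuffle=True, num_workers=2, drop_last=False)

left, right, disp = next(iter(train_loader))
print(left.shape, right.shape, disp.shape)   # random 288x624 crops from __getitem__
```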
/models/feature_extraction.py:
--------------------------------------------------------------------------------
1 | #coding=utf-8
2 | from __future__ import print_function
3 | import torch.nn as nn
4 |
5 |
6 | class F1(nn.Module):
7 | def __init__(self):
8 | super(F1, self).__init__()
9 | # feature extraction
10 | self.init_feature = nn.Sequential(
11 |
12 | # 6-24
13 | nn.Conv2d(3, 4, 3, 1, 1, bias=False),
14 | nn.BatchNorm2d(4),
15 | nn.ELU(inplace=True),
16 | nn.Conv2d(4, 4, 3, 2, 1, bias=False),
17 | nn.Conv2d(4, 8, 3, 1, 1, bias=False),
18 |
19 | )
20 |
21 | def forward(self, x_left):
22 |
23 | buffer_left = self.init_feature(x_left)
24 |
25 | return buffer_left
26 |
27 |
28 |
29 | class F2(nn.Module):
30 | def __init__(self):
31 | super(F2, self).__init__()
32 |
33 | self.init_feature = nn.Sequential(
34 |
35 |
36 | nn.MaxPool2d(2, 2),
37 | nn.BatchNorm2d(8),
38 | nn.ELU(inplace=True),
39 |
40 | nn.Conv2d(8, 12, 3, 1, 1, bias=False),
41 | nn.BatchNorm2d(12),
42 | nn.ELU(inplace=True),
43 | nn.Conv2d(12, 12, 3, 1, 1, bias=False),
44 |
45 | )
46 |
47 | def forward(self, x_left):
48 |
49 | buffer_left = self.init_feature(x_left)
50 |
51 | return buffer_left
52 |
53 |
54 | class F3(nn.Module):
55 | def __init__(self):
56 | super(F3, self).__init__()
57 |
58 | self.init_feature = nn.Sequential(
59 |
60 | nn.MaxPool2d(2, 2),
61 | nn.BatchNorm2d(12),
62 | nn.ELU(inplace=True),
63 |
64 | nn.Conv2d(12, 16, 3, 1, 1, bias=False),
65 | nn.BatchNorm2d(16),
66 | nn.ELU(inplace=True),
67 | nn.Conv2d(16, 16, 3, 1, 1, bias=False),
68 |
69 | )
70 |
71 | def forward(self, x_left):
72 |
73 | buffer_left = self.init_feature(x_left)
74 |
75 | return buffer_left
76 |
77 |
78 |
79 |
80 | class F3_UP(nn.Module):
81 | def __init__(self):
82 | super(F3_UP, self).__init__()
83 | self.init_feature = nn.Sequential(
84 |
85 | nn.Conv2d(16, 16, 3, 1, 1, bias=False),
86 |
87 | nn.BatchNorm2d(16),
88 |
89 | nn.ELU(inplace=True),
90 |
91 | nn.ConvTranspose2d(16, 12, 3, 2, 1, output_padding=1, bias=False),
92 | )
93 |
94 | def forward(self, x_left):
95 |
96 | buffer_left = self.init_feature(x_left)
97 |
98 | return buffer_left
99 |
100 |
101 | class F2_UP(nn.Module):
102 | def __init__(self):
103 | super(F2_UP, self).__init__()
104 |
105 | # cat
106 | self.init_feature = nn.Sequential(
107 |
108 | nn.BatchNorm2d(24),
109 | nn.ELU(inplace=True),
110 | nn.ConvTranspose2d(24, 8, 3, 2, 1, output_padding=1, bias=False),
111 | )
112 |
113 |
114 |
115 | def forward(self, x_left):
116 | ### feature extraction
117 | buffer_left = self.init_feature(x_left)
118 |
119 | return buffer_left
120 |
121 |
122 | class F1_UP(nn.Module):
123 | def __init__(self):
124 | super(F1_UP, self).__init__()
125 | # cat
126 | self.init_feature = nn.Sequential(
127 |
128 | nn.BatchNorm2d(16),
129 | nn.ELU(inplace=True),
130 | nn.ConvTranspose2d(16, 8, 3, 2, 1, output_padding=1, bias=False),
131 |
132 | )
133 |
134 |
135 | def forward(self, x_left):
136 | ### feature extraction
137 | buffer_left = self.init_feature(x_left)
138 |
139 | return buffer_left
140 |
141 |
142 |
143 |
--------------------------------------------------------------------------------
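A minimal shape-check sketch for the feature pyramid defined above, assuming a 320×1216 input (the crop size used by the KITTI loaders elsewhere in this repo); the batch size and dummy tensor are made up.

```python
# Sketch only: trace tensor shapes through F1/F2/F3 and the F3_UP decoder stage.
import torch
from models.feature_extraction import F1, F2, F3, F3_UP

x = torch.randn(1, 3, 320, 1216)   # left image crop
f1 = F1()(x)                       # stride-2 conv   -> 1/2 resolution, 8 channels
f2 = F2()(f1)                      # max-pool        -> 1/4 resolution, 12 channels
f3 = F3()(f2)                      # max-pool        -> 1/8 resolution, 16 channels
up3 = F3_UP()(f3)                  # transposed conv -> back to 1/4 resolution, 12 channels
print(f1.shape, f2.shape, f3.shape, up3.shape)
# F2_UP then expects torch.cat([f2, up3], dim=1), i.e. 12 + 12 = 24 input channels.
```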
/env.yaml:
--------------------------------------------------------------------------------
1 | name: LWANet
2 | channels:
3 | - pytorch
4 | - anaconda
5 | - defaults
6 | dependencies:
7 | - _libgcc_mutex=0.1=main
8 | - blas=1.0=mkl
9 | - bzip2=1.0.8=h7b6447c_0
10 | - ca-certificates=2020.6.24=0
11 | - cairo=1.14.12=h8948797_3
12 | - certifi=2020.6.20=py37_0
13 | - cffi=1.12.3=py37h2e261b9_0
14 | - cloudpickle=1.2.1=py_0
15 | - cuda80=1.0=h205658b_0
16 | - cudatoolkit=10.0.130=0
17 | - cycler=0.10.0=py37_0
18 | - cytoolz=0.10.0=py37h7b6447c_0
19 | - dask-core=2.3.0=py_0
20 | - dbus=1.13.6=h746ee38_0
21 | - decorator=4.4.0=py37_1
22 | - expat=2.2.6=he6710b0_0
23 | - ffmpeg=4.0=hcdf2ecd_0
24 | - fontconfig=2.13.0=h9420a91_0
25 | - freeglut=3.0.0=hf484d3e_5
26 | - freetype=2.9.1=h8a8886c_1
27 | - glib=2.56.2=hd408876_0
28 | - graphite2=1.3.13=h23475e2_0
29 | - gst-plugins-base=1.14.0=hbbd80ab_1
30 | - gstreamer=1.14.0=hb453b48_1
31 | - harfbuzz=1.8.8=hffaf4a1_0
32 | - hdf5=1.10.2=hba1933b_1
33 | - icu=58.2=h9c2bf20_1
34 | - imageio=2.5.0=py37_0
35 | - intel-openmp=2019.4=243
36 | - jasper=2.0.14=h07fcdf6_1
37 | - jpeg=9b=h024ee3a_2
38 | - kiwisolver=1.1.0=py37he6710b0_0
39 | - libedit=3.1.20181209=hc058e9b_0
40 | - libffi=3.2.1=hd88cf55_4
41 | - libgcc-ng=9.1.0=hdf63c60_0
42 | - libgfortran-ng=7.3.0=hdf63c60_0
43 | - libglu=9.0.0=hf484d3e_1
44 | - libopencv=3.4.2=hb342d67_1
45 | - libopus=1.3=h7b6447c_0
46 | - libpng=1.6.37=hbc83047_0
47 | - libstdcxx-ng=9.1.0=hdf63c60_0
48 | - libtiff=4.0.10=h2733197_2
49 | - libuuid=1.0.3=h1bed415_2
50 | - libvpx=1.7.0=h439df22_0
51 | - libxcb=1.13=h1bed415_1
52 | - libxml2=2.9.9=hea5a465_1
53 | - matplotlib=3.1.1=py37h5429711_0
54 | - mkl=2019.4=243
55 | - mkl-service=2.0.2=py37h7b6447c_0
56 | - mkl_fft=1.0.14=py37ha843d7b_0
57 | - mkl_random=1.0.2=py37hd81dba3_0
58 | - ncurses=6.1=he6710b0_1
59 | - networkx=2.3=py_0
60 | - ninja=1.9.0=py37hfd86e86_0
61 | - olefile=0.46=py37_0
62 | - opencv=3.4.2=py37h6fd60c2_1
63 | - openssl=1.1.1g=h7b6447c_0
64 | - pcre=8.43=he6710b0_0
65 | - pillow=6.1.0=py37h34e0f95_0
66 | - pip=19.2.2=py37_0
67 | - pixman=0.38.0=h7b6447c_0
68 | - py-opencv=3.4.2=py37hb342d67_1
69 | - pycparser=2.19=py37_0
70 | - pyparsing=2.4.2=py_0
71 | - pyqt=5.9.2=py37h05f1152_2
72 | - python=3.7.4=h265db76_1
73 | - python-dateutil=2.8.0=py37_0
74 | - pytorch=1.0.0=py3.7_cuda8.0.61_cudnn7.1.2_1
75 | - pytz=2019.2=py_0
76 | - pywavelets=1.0.3=py37hdd07704_1
77 | - qt=5.9.7=h5867ecd_1
78 | - readline=7.0=h7b6447c_5
79 | - scikit-image=0.15.0=py37he6710b0_0
80 | - scipy=1.3.1=py37h7c811a0_0
81 | - setuptools=41.0.1=py37_0
82 | - sip=4.19.8=py37hf484d3e_0
83 | - six=1.12.0=py37_0
84 | - sqlite=3.29.0=h7b6447c_0
85 | - tk=8.6.8=hbc83047_0
86 | - toolz=0.10.0=py_0
87 | - torchvision=0.2.1=py_2
88 | - tornado=6.0.3=py37h7b6447c_0
89 | - wheel=0.33.4=py37_0
90 | - xz=5.2.4=h14c3975_4
91 | - zlib=1.2.11=h7b6447c_3
92 | - zstd=1.3.7=h0b5b093_0
93 | - pip:
94 | - absl-py==0.8.0
95 | - apex==0.1
96 | - astor==0.8.0
97 | - bleach==1.5.0
98 | - chardet==3.0.4
99 | - future==0.17.1
100 | - gast==0.2.2
101 | - google-pasta==0.1.7
102 | - grpcio==1.23.0
103 | - h5py==2.9.0
104 | - html5lib==0.9999999
105 | - idna==2.10
106 | - imageio-ffmpeg==0.4.2
107 | - keras-applications==1.0.8
108 | - keras-preprocessing==1.1.0
109 | - markdown==3.1.1
110 | - moviepy==1.0.3
111 | - numpy==1.19.4
112 | - proglog==0.1.9
113 | - protobuf==3.9.1
114 | - requests==2.25.1
115 | - tb-nightly==1.15.0a20190902
116 | - termcolor==1.1.0
117 | - thop==0.0.31-2005241907
118 | - tqdm==4.54.1
119 | - urllib3==1.26.2
120 | - werkzeug==0.15.5
121 | - wrapt==1.11.2
122 | - yum==0.0.1
123 | prefix: /home/wsgan/anaconda3/envs/aanet
124 |
--------------------------------------------------------------------------------
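The environment can typically be recreated with `conda env create -f env.yaml` followed by `conda activate LWANet`; the `prefix:` line at the bottom points at the original author's machine and may need to be removed or overridden on other systems. Note that the pinned `pytorch=1.0.0` build string targets an old CUDA stack, so newer GPUs may require adjusting the PyTorch and cudatoolkit pins.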
/dataloader/listflowfile.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 | from PIL import Image
4 | import os
5 | import os.path
6 |
7 | IMG_EXTENSIONS = [
8 | '.jpg', '.JPG', '.jpeg', '.JPEG',
9 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
10 | ]
11 |
12 |
13 | def is_image_file(filename):
14 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
15 |
16 | def dataloader(filepath):
17 |
18 | classes = [d for d in os.listdir(filepath) if os.path.isdir(os.path.join(filepath, d))]
19 | image = [img for img in classes if img.find('frames_cleanpass') > -1]
20 | disp = [dsp for dsp in classes if dsp.find('disparity') > -1]
21 |
22 | print('len image:',len(image))
23 |
24 | monkaa_path = filepath + '' + [x for x in image if 'monkaa' in x][0]
25 | monkaa_disp = filepath + [x for x in disp if 'monkaa' in x][0]
26 | # monkaa_path = filepath + 'monkaa_frames_cleanpass'
27 | # monkaa_disp = filepath + 'monkaa_disparity'
28 |
29 | monkaa_dir = os.listdir(monkaa_path)
30 |
31 |
32 | all_left_img=[]
33 | all_right_img=[]
34 | all_left_disp = []
35 | test_left_img=[]
36 | test_right_img=[]
37 | test_left_disp = []
38 |
39 |
40 | for dd in monkaa_dir:
41 |
42 | for im in os.listdir(monkaa_path+'/'+dd+'/left/'):
43 | if is_image_file(monkaa_path+'/'+dd+'/left/'+im):
44 | all_left_img.append(monkaa_path+'/'+dd+'/left/'+im)
45 | all_left_disp.append(monkaa_disp+'/'+dd+'/left/'+im.split(".")[0]+'.pfm')
46 |
47 | for im in os.listdir(monkaa_path+'/'+dd+'/right/'):
48 | if is_image_file(monkaa_path+'/'+dd+'/right/'+im):
49 | all_right_img.append(monkaa_path+'/'+dd+'/right/'+im)
50 |
51 |
52 |
53 |
54 |
55 | flying_path = filepath + [x for x in image if x == 'frames_cleanpass'][0]
56 | flying_disp = filepath + [x for x in disp if x == 'frames_disparity'][0]
57 | flying_dir = flying_path+'/TRAIN/'
58 | subdir = ['A','B','C']
59 |
60 | for ss in subdir:
61 | flying = os.listdir(flying_dir+ss)
62 |
63 | for ff in flying:
64 | imm_l = os.listdir(flying_dir+ss+'/'+ff+'/left/')
65 | for im in imm_l:
66 | if is_image_file(flying_dir+ss+'/'+ff+'/left/'+im):
67 | all_left_img.append(flying_dir+ss+'/'+ff+'/left/'+im)
68 |
69 | all_left_disp.append(flying_disp+'/TRAIN/'+ss+'/'+ff+'/left/'+im.split(".")[0]+'.pfm')
70 |
71 | if is_image_file(flying_dir+ss+'/'+ff+'/right/'+im):
72 | all_right_img.append(flying_dir+ss+'/'+ff+'/right/'+im)
73 |
74 | flying_dir = flying_path+'/TEST/'
75 |
76 | subdir = ['A','B','C']
77 |
78 |
79 | for ss in subdir:
80 | flying = os.listdir(flying_dir+ss)
81 |
82 | for ff in flying:
83 | imm_l = os.listdir(flying_dir+ss+'/'+ff+'/left/')
84 | for im in imm_l:
85 | if is_image_file(flying_dir+ss+'/'+ff+'/left/'+im):
86 | test_left_img.append(flying_dir+ss+'/'+ff+'/left/'+im)
87 |
88 | test_left_disp.append(flying_disp+'/TEST/'+ss+'/'+ff+'/left/'+im.split(".")[0]+'.pfm')
89 |
90 | if is_image_file(flying_dir+ss+'/'+ff+'/right/'+im):
91 | test_right_img.append(flying_dir+ss+'/'+ff+'/right/'+im)
92 |
93 |
94 |
95 |
96 |
97 | driving_dir = filepath + [x for x in image if 'driving' in x][0] + '/'
98 | driving_disp = filepath + [x for x in disp if 'driving' in x][0]
99 |
100 | subdir1 = ['35mm_focallength','15mm_focallength']
101 | subdir2 = ['scene_backwards','scene_forwards']
102 | subdir3 = ['fast','slow']
103 |
104 |
105 | for i in subdir1:
106 | for j in subdir2:
107 | for k in subdir3:
108 | imm_l = os.listdir(driving_dir+i+'/'+j+'/'+k+'/left/')
109 | for im in imm_l:
110 | if is_image_file(driving_dir+i+'/'+j+'/'+k+'/left/'+im):
111 | all_left_img.append(driving_dir+i+'/'+j+'/'+k+'/left/'+im)
112 | all_left_disp.append(driving_disp+'/'+i+'/'+j+'/'+k+'/left/'+im.split(".")[0]+'.pfm')
113 |
114 | if is_image_file(driving_dir+i+'/'+j+'/'+k+'/right/'+im):
115 | all_right_img.append(driving_dir+i+'/'+j+'/'+k+'/right/'+im)
116 |
117 |
118 | print('len all_left_img:', len(all_left_img))
119 | print('len test_left_img:', len(test_left_img))
120 |
121 | return all_left_img, all_right_img, all_left_disp, test_left_img, test_right_img, test_left_disp
122 |
123 |
124 |
--------------------------------------------------------------------------------
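The lists returned above are presumably consumed by `SecenFlowLoader.myImageFloder` (shown earlier); a brief sketch with an example dataset root:

```python
# Hypothetical call; '/data/SceneFlow/' is an example root containing the folders matched above.
from dataloader import listflowfile as lt
from dataloader import SecenFlowLoader as SFL

all_l, all_r, all_ld, test_l, test_r, test_ld = lt.dataloader('/data/SceneFlow/')
train_set = SFL.myImageFloder(all_l, all_r, all_ld, training=True)       # 512x960 random crops
test_set = SFL.myImageFloder(test_l, test_r, test_ld, training=False)    # 960x544 fixed crops
```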
/utils/merge_img2video.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | from PIL import Image, ImageDraw, ImageFont
4 | import numpy as np
5 | import pdb
6 |
7 | def merge_img2video(image_path, video_path):
8 | # path = image_path # directory containing the image sequence, file names: 0.jpg 1.jpg ...
9 | # dst_path = video_path #r'F:\dst\result.mp4' # path of the generated video
10 |
11 | filelist = os.listdir(image_path)
12 | filepref = [os.path.splitext(f)[0] for f in filelist]
13 |
14 |
15 |
16 | filepref.sort(key = int) # sort by numeric file name
17 | #filepref= sorted(filepref,key=lambda x: int(x[:-6])) # sort by numeric file name
18 |
19 | #pdb.set_trace()
20 |
21 | filelist = [f + '.png' for f in filepref]
22 |
23 | # size = (int(videoCapture.get(cv2.cv.CV_CAP_PROP_FRAME_WIDTH)),
24 | # int(videoCapture.get(cv2.cv.CV_CAP_PROP_FRAME_HEIGHT)))
25 | width = 1216
26 | height = 320
27 |
28 | # width = 1238
29 | # height = 374
30 | fps = 30
31 |
32 | vw = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'DIVX'), fps, (width , height))
33 |
34 | #for file in filelist[5:-6]:
35 | for file in filelist:
36 | if file.endswith('.png'):
37 | file = os.path.join(image_path, file)
38 | print("file:", file)
39 | img = cv2.imread(file)
40 | print("img:", img.shape)
41 |
42 |
43 | # img = img[54:,22:,:]
44 | # #img = img[50:, 10:, :]
45 |
46 |
47 | print("img:", img.shape)
48 | #img = np.hstack((img, img)) # if showing two columns side by side
49 | vw.write(img)
50 |
51 |
52 | vw.release()
53 |
54 |
55 |
56 | def merge_video():
57 |
58 | videoLeftUp = cv2.VideoCapture('/home/wsgan/LWANet/results/video/0028/raw_img_title.mp4')
59 | videoLeftDown = cv2.VideoCapture('/home/wsgan/LWANet/results/video/0028/GT_supervise/GT_supervise_subtitile.mp4')
60 | videoRightUp = cv2.VideoCapture('/home/wsgan/LWANet/results/video/0028/self_supervise/Self_supervise_subtitle.mp4')
61 | videoRightDown = cv2.VideoCapture('/home/wsgan/LWANet/results/video/0028/no_supervise/No_supervise_subtitle.mp4')
62 |
63 | fps = videoLeftUp.get(cv2.CAP_PROP_FPS)
64 |
65 | width = (int(videoLeftUp.get(cv2.CAP_PROP_FRAME_WIDTH)))
66 | height = (int(videoLeftUp.get(cv2.CAP_PROP_FRAME_HEIGHT)))
67 |
68 | videoWriter = cv2.VideoWriter('/home/wsgan/LWANet/results/video/0028/merge0028.mp4', cv2.VideoWriter_fourcc('m', 'p', '4', 'v'), fps, (width, height))
69 |
70 | successLeftUp, frameLeftUp = videoLeftUp.read()
71 | successLeftDown, frameLeftDown = videoLeftDown.read()
72 | successRightUp, frameRightUp = videoRightUp.read()
73 | successRightDown, frameRightDown = videoRightDown.read()
74 |
75 | while successLeftUp and successLeftDown and successRightUp and successRightDown:
76 | frameLeftUp = cv2.resize(frameLeftUp, (int(width / 2), int(height / 2)), interpolation=cv2.INTER_CUBIC)
77 | frameLeftDown = cv2.resize(frameLeftDown, (int(width / 2), int(height / 2)), interpolation=cv2.INTER_CUBIC)
78 | frameRightUp = cv2.resize(frameRightUp, (int(width / 2), int(height / 2)), interpolation=cv2.INTER_CUBIC)
79 | frameRightDown = cv2.resize(frameRightDown, (int(width / 2), int(height / 2)), interpolation=cv2.INTER_CUBIC)
80 |
81 | frameUp = np.hstack((frameLeftUp, frameRightUp))
82 | frameDown = np.hstack((frameLeftDown, frameRightDown))
83 | frame = np.vstack((frameUp, frameDown))
84 |
85 | videoWriter.write(frame)
86 | successLeftUp, frameLeftUp = videoLeftUp.read()
87 | successLeftDown, frameLeftDown = videoLeftDown.read()
88 | successRightUp, frameRightUp = videoRightUp.read()
89 | successRightDown, frameRightDown = videoRightDown.read()
90 |
91 | videoWriter.release()
92 | videoLeftUp.release()
93 | videoLeftDown.release()
94 | videoRightUp.release()
95 | videoRightDown.release()
96 |
97 |
98 |
99 | def add_subtitle(video_path, save_path):
100 |
101 | cap = cv2.VideoCapture(video_path) # read the video
102 |
103 | # Define the codec and create VideoWriter object
104 | fourcc = cv2.VideoWriter_fourcc(*'XVID')
105 | out = cv2.VideoWriter(save_path, fourcc, 30.0, (1216, 320)) # output video parameters
106 |
107 | while (cap.isOpened()):
108 | ret, frame = cap.read()
109 | if ret == True:
110 | # draw some text on the frame
111 | img_PIL = Image.fromarray(frame[..., ::-1]) # convert the BGR frame to a PIL image
112 | font = ImageFont.truetype('UbuntuMono-B.ttf', 40) # font setting; on Windows, fonts can be found under "C:\Windows\Fonts"
113 | text1 = "Self_supervise"
114 |
115 | for i, te in enumerate(text1):
116 | # position = (50, 10 + i * 50)
117 | position = (10 + i * 20, 20 )
118 | draw = ImageDraw.Draw(img_PIL)
119 | draw.text(position, te, font=font, fill=(255, 0, 0))
120 |
121 | frame = cv2.cvtColor(np.asarray(img_PIL), cv2.COLOR_RGB2BGR)
122 |
123 | # write the frame
124 | #cv2.imshow('frame', frame)
125 | out.write(frame)
126 | # if cv2.waitKey(1) & 0xFF == ord('q'):
127 | # break
128 | else:
129 | break
130 |
131 | # Release everything if job is finished
132 | cap.release()
133 | out.release()
134 | #cv2.destroyAllWindows()
135 |
136 |
137 |
138 |
139 |
140 |
141 | if __name__ == "__main__":
142 | print('start')
143 | #merge_img2video('/home/wsgan/LWANet/results/video/0071/self_supervise/disparity', '/home/wsgan/LWANet/results/video/0071/self_supervise/Self_supervise.mp4' )
144 | merge_video()
145 | #add_subtitle('/home/wsgan/LWANet/results/video/0071/self_supervise/Self_supervise.mp4', '/home/wsgan/LWANet/results/video/0071/self_supervise/Self_supervise_subtitle.mp4')
146 | print('end')
--------------------------------------------------------------------------------
/utils/preprocess.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision.transforms as transforms
3 | import random
4 |
5 | __imagenet_stats = {'mean': [0.485, 0.456, 0.406],
6 | 'std': [0.229, 0.224, 0.225]}
7 |
8 | #__imagenet_stats = {'mean': [0.5, 0.5, 0.5],
9 | # 'std': [0.5, 0.5, 0.5]}
10 |
11 | __imagenet_pca = {
12 | 'eigval': torch.Tensor([0.2175, 0.0188, 0.0045]),
13 | 'eigvec': torch.Tensor([
14 | [-0.5675, 0.7192, 0.4009],
15 | [-0.5808, -0.0045, -0.8140],
16 | [-0.5836, -0.6948, 0.4203],
17 | ])
18 | }
19 |
20 |
21 | def scale_crop(input_size, scale_size=None, normalize=__imagenet_stats):
22 | t_list = [
23 | transforms.ToTensor(),
24 | transforms.Normalize(**normalize),
25 | ]
26 | #if scale_size != input_size:
27 | #t_list = [transforms.Scale((960,540))] + t_list
28 |
29 | return transforms.Compose(t_list)
30 |
31 |
32 | def scale_random_crop(input_size, scale_size=None, normalize=__imagenet_stats):
33 | t_list = [
34 | transforms.RandomCrop(input_size),
35 | transforms.ToTensor(),
36 | transforms.Normalize(**normalize),
37 | ]
38 | if scale_size != input_size:
39 | t_list = [transforms.Scale(scale_size)] + t_list
40 |
41 |     return transforms.Compose(t_list)
42 |
43 |
44 | def pad_random_crop(input_size, scale_size=None, normalize=__imagenet_stats):
45 | padding = int((scale_size - input_size) / 2)
46 | return transforms.Compose([
47 | transforms.RandomCrop(input_size, padding=padding),
48 | transforms.RandomHorizontalFlip(),
49 | transforms.ToTensor(),
50 | transforms.Normalize(**normalize),
51 | ])
52 |
53 |
54 | def inception_preproccess(input_size, normalize=__imagenet_stats):
55 | return transforms.Compose([
56 | transforms.RandomSizedCrop(input_size),
57 | transforms.RandomHorizontalFlip(),
58 | transforms.ToTensor(),
59 | transforms.Normalize(**normalize)
60 | ])
61 | def inception_color_preproccess(input_size, normalize=__imagenet_stats):
62 | return transforms.Compose([
63 | #transforms.RandomSizedCrop(input_size),
64 | #transforms.RandomHorizontalFlip(),
65 | transforms.ToTensor(),
66 | ColorJitter(
67 | brightness=0.4,
68 | contrast=0.4,
69 | saturation=0.4,
70 | ),
71 | Lighting(0.1, __imagenet_pca['eigval'], __imagenet_pca['eigvec']),
72 | transforms.Normalize(**normalize)
73 | ])
74 |
75 |
76 | def get_transform(name='imagenet', input_size=None,
77 | scale_size=None, normalize=None, augment=True):
78 |     normalize = normalize or __imagenet_stats
79 |     input_size = input_size or 256
80 | if augment:
81 | return inception_color_preproccess(input_size, normalize=normalize)
82 | else:
83 | return scale_crop(input_size=input_size,
84 | scale_size=scale_size, normalize=normalize)
85 |
86 |
87 |
88 |
89 | class Lighting(object):
90 | """Lighting noise(AlexNet - style PCA - based noise)"""
91 |
92 | def __init__(self, alphastd, eigval, eigvec):
93 | self.alphastd = alphastd
94 | self.eigval = eigval
95 | self.eigvec = eigvec
96 |
97 | def __call__(self, img):
98 | if self.alphastd == 0:
99 | return img
100 |
101 | alpha = img.new().resize_(3).normal_(0, self.alphastd)
102 | rgb = self.eigvec.type_as(img).clone()\
103 | .mul(alpha.view(1, 3).expand(3, 3))\
104 | .mul(self.eigval.view(1, 3).expand(3, 3))\
105 | .sum(1).squeeze()
106 |
107 | return img.add(rgb.view(3, 1, 1).expand_as(img))
108 |
109 |
110 | class Grayscale(object):
111 |
112 | def __call__(self, img):
113 | gs = img.clone()
114 | gs[0].mul_(0.299).add_(0.587, gs[1]).add_(0.114, gs[2])
115 | gs[1].copy_(gs[0])
116 | gs[2].copy_(gs[0])
117 | return gs
118 |
119 |
120 | class Saturation(object):
121 |
122 | def __init__(self, var):
123 | self.var = var
124 |
125 | def __call__(self, img):
126 | gs = Grayscale()(img)
127 | alpha = random.uniform(0, self.var)
128 | return img.lerp(gs, alpha)
129 |
130 |
131 | class Brightness(object):
132 |
133 | def __init__(self, var):
134 | self.var = var
135 |
136 | def __call__(self, img):
137 | gs = img.new().resize_as_(img).zero_()
138 | alpha = random.uniform(0, self.var)
139 | return img.lerp(gs, alpha)
140 |
141 |
142 | class Contrast(object):
143 |
144 | def __init__(self, var):
145 | self.var = var
146 |
147 | def __call__(self, img):
148 | gs = Grayscale()(img)
149 | gs.fill_(gs.mean())
150 | alpha = random.uniform(0, self.var)
151 | return img.lerp(gs, alpha)
152 |
153 |
154 | class RandomOrder(object):
155 | """ Composes several transforms together in random order.
156 | """
157 |
158 | def __init__(self, transforms):
159 | self.transforms = transforms
160 |
161 | def __call__(self, img):
162 | if self.transforms is None:
163 | return img
164 | order = torch.randperm(len(self.transforms))
165 | for i in order:
166 | img = self.transforms[i](img)
167 | return img
168 |
169 |
170 | class ColorJitter(RandomOrder):
171 |
172 | def __init__(self, brightness=0.4, contrast=0.4, saturation=0.4):
173 | self.transforms = []
174 | if brightness != 0:
175 | self.transforms.append(Brightness(brightness))
176 | if contrast != 0:
177 | self.transforms.append(Contrast(contrast))
178 | if saturation != 0:
179 | self.transforms.append(Saturation(saturation))
180 |
--------------------------------------------------------------------------------
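A minimal sketch of how get_transform() above is typically used by the data loaders; the dummy image below is generated on the fly and is not repository data:

```python
from PIL import Image
import numpy as np
from utils.preprocess import get_transform  # module path as in this repository

# A random RGB image stands in for a KITTI frame.
img = Image.fromarray(np.random.randint(0, 255, (256, 512, 3), dtype=np.uint8))

# augment=False -> ToTensor + ImageNet normalization only (scale_crop above).
transform = get_transform(augment=False)
tensor = transform(img)
print(tensor.shape)  # torch.Size([3, 256, 512]); values normalized with the ImageNet statistics
```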
/dataloader/preprocess.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision.transforms as transforms
3 | import random
4 |
5 | __imagenet_stats = {'mean': [0.485, 0.456, 0.406],
6 | 'std': [0.229, 0.224, 0.225]}
7 |
8 | #__imagenet_stats = {'mean': [0.5, 0.5, 0.5],
9 | # 'std': [0.5, 0.5, 0.5]}
10 |
11 | __imagenet_pca = {
12 | 'eigval': torch.Tensor([0.2175, 0.0188, 0.0045]),
13 | 'eigvec': torch.Tensor([
14 | [-0.5675, 0.7192, 0.4009],
15 | [-0.5808, -0.0045, -0.8140],
16 | [-0.5836, -0.6948, 0.4203],
17 | ])
18 | }
19 |
20 |
21 | def scale_crop(input_size, scale_size=None, normalize=__imagenet_stats):
22 | t_list = [
23 | transforms.ToTensor(),
24 | transforms.Normalize(**normalize),
25 | ]
26 | #if scale_size != input_size:
27 | #t_list = [transforms.Scale((960,540))] + t_list
28 |
29 | return transforms.Compose(t_list)
30 |
31 |
32 | def scale_random_crop(input_size, scale_size=None, normalize=__imagenet_stats):
33 | t_list = [
34 | transforms.RandomCrop(input_size),
35 | transforms.ToTensor(),
36 | transforms.Normalize(**normalize),
37 | ]
38 | if scale_size != input_size:
39 | t_list = [transforms.Scale(scale_size)] + t_list
40 |
41 |     return transforms.Compose(t_list)
42 |
43 |
44 | def pad_random_crop(input_size, scale_size=None, normalize=__imagenet_stats):
45 | padding = int((scale_size - input_size) / 2)
46 | return transforms.Compose([
47 | transforms.RandomCrop(input_size, padding=padding),
48 | transforms.RandomHorizontalFlip(),
49 | transforms.ToTensor(),
50 | transforms.Normalize(**normalize),
51 | ])
52 |
53 |
54 | def inception_preproccess(input_size, normalize=__imagenet_stats):
55 | return transforms.Compose([
56 | transforms.RandomSizedCrop(input_size),
57 | transforms.RandomHorizontalFlip(),
58 | transforms.ToTensor(),
59 | transforms.Normalize(**normalize)
60 | ])
61 | def inception_color_preproccess(input_size, normalize=__imagenet_stats):
62 | return transforms.Compose([
63 | #transforms.RandomSizedCrop(input_size),
64 | #transforms.RandomHorizontalFlip(),
65 | transforms.ToTensor(),
66 | ColorJitter(
67 | brightness=0.4,
68 | contrast=0.4,
69 | saturation=0.4,
70 | ),
71 | Lighting(0.1, __imagenet_pca['eigval'], __imagenet_pca['eigvec']),
72 | transforms.Normalize(**normalize)
73 | ])
74 |
75 |
76 | def get_transform(name='imagenet', input_size=None,
77 | scale_size=None, normalize=None, augment=True):
78 |     normalize = normalize or __imagenet_stats
79 |     input_size = input_size or 256
80 | if augment:
81 | return inception_color_preproccess(input_size, normalize=normalize)
82 | else:
83 | return scale_crop(input_size=input_size,
84 | scale_size=scale_size, normalize=normalize)
85 |
86 |
87 |
88 |
89 | class Lighting(object):
90 | """Lighting noise(AlexNet - style PCA - based noise)"""
91 |
92 | def __init__(self, alphastd, eigval, eigvec):
93 | self.alphastd = alphastd
94 | self.eigval = eigval
95 | self.eigvec = eigvec
96 |
97 | def __call__(self, img):
98 | if self.alphastd == 0:
99 | return img
100 |
101 | alpha = img.new().resize_(3).normal_(0, self.alphastd)
102 | rgb = self.eigvec.type_as(img).clone()\
103 | .mul(alpha.view(1, 3).expand(3, 3))\
104 | .mul(self.eigval.view(1, 3).expand(3, 3))\
105 | .sum(1).squeeze()
106 |
107 | return img.add(rgb.view(3, 1, 1).expand_as(img))
108 |
109 |
110 | class Grayscale(object):
111 |
112 | def __call__(self, img):
113 | gs = img.clone()
114 | gs[0].mul_(0.299).add_(0.587, gs[1]).add_(0.114, gs[2])
115 | gs[1].copy_(gs[0])
116 | gs[2].copy_(gs[0])
117 | return gs
118 |
119 |
120 | class Saturation(object):
121 |
122 | def __init__(self, var):
123 | self.var = var
124 |
125 | def __call__(self, img):
126 | gs = Grayscale()(img)
127 | alpha = random.uniform(0, self.var)
128 | return img.lerp(gs, alpha)
129 |
130 |
131 | class Brightness(object):
132 |
133 | def __init__(self, var):
134 | self.var = var
135 |
136 | def __call__(self, img):
137 | gs = img.new().resize_as_(img).zero_()
138 | alpha = random.uniform(0, self.var)
139 | return img.lerp(gs, alpha)
140 |
141 |
142 | class Contrast(object):
143 |
144 | def __init__(self, var):
145 | self.var = var
146 |
147 | def __call__(self, img):
148 | gs = Grayscale()(img)
149 | gs.fill_(gs.mean())
150 | alpha = random.uniform(0, self.var)
151 | return img.lerp(gs, alpha)
152 |
153 |
154 | class RandomOrder(object):
155 | """ Composes several transforms together in random order.
156 | """
157 |
158 | def __init__(self, transforms):
159 | self.transforms = transforms
160 |
161 | def __call__(self, img):
162 | if self.transforms is None:
163 | return img
164 | order = torch.randperm(len(self.transforms))
165 | for i in order:
166 | img = self.transforms[i](img)
167 | return img
168 |
169 |
170 | class ColorJitter(RandomOrder):
171 |
172 | def __init__(self, brightness=0.4, contrast=0.4, saturation=0.4):
173 | self.transforms = []
174 | if brightness != 0:
175 | self.transforms.append(Brightness(brightness))
176 | if contrast != 0:
177 | self.transforms.append(Contrast(contrast))
178 | if saturation != 0:
179 | self.transforms.append(Saturation(saturation))
180 |
--------------------------------------------------------------------------------
/submission.py:
--------------------------------------------------------------------------------
1 | #coding=utf-8
2 | import torch
3 | import torch.nn as nn
4 | import numpy as np
5 | import torchvision.transforms as transforms
6 | import argparse
7 | np.set_printoptions(threshold=np.inf)
8 | import torch.nn.functional as F
9 | from PIL import Image
10 | import utils.logger as logger
11 | import os, time
12 | from models.LWANet import *
13 |
14 |
15 | parser = argparse.ArgumentParser(description='LWANet submission')
16 | parser.add_argument('--cuda', type=bool, default=True, help='use CUDA if available')
17 | parser.add_argument('--pretrained', type=str, default=None, help='path to the pretrained checkpoint')
18 | parser.add_argument('--maxdisp', type=int, default=192, help='maximum disparity')
19 | parser.add_argument('--loss_weights', type=float, nargs='+', default=[1., 1.])
20 | parser.add_argument('--maxdisplist', type=int, nargs='+', default=[24, 3, 3])
21 | parser.add_argument('--lr', type=float, default=5e-4, help='learning rate')
22 | parser.add_argument('--with_cspn', type =bool, default= True, help='with cspn network or not')
23 | parser.add_argument('--datapath', default='/data6/wsgan/SenceFlow/train/', help='datapath')
24 | parser.add_argument('--epochs', type=int, default=50, help='number of epochs to train')
25 | parser.add_argument('--train_bsize', type=int, default=8, help='batch size for training (default: 8)')
26 | parser.add_argument('--test_bsize', type=int, default=8, help='batch size for testing (default: 8)')
27 | parser.add_argument('--save_path', type=str, default='./results/kitti2015/benchmark', help='the path of saving checkpoints and log')
28 | parser.add_argument('--resume', type=str, default=None, help='resume path')
29 | parser.add_argument('--print_freq', type=int, default=400, help='print frequency')
30 |
31 | parser.add_argument('--model_types', type=str, default='original', help='model_types : LWANet_3D, mix, original')
32 | parser.add_argument('--conv_3d_types1', type=str, default='separate_only', help='model_types : 3D, P3D ')
33 | parser.add_argument('--conv_3d_types2', type=str, default='separate_only', help='model_types : 3D, P3D')
34 | parser.add_argument('--cost_volume', type=str, default='Difference', help='cost_volume type : "Concat" , "Difference" or "Distance_based" ')
35 | parser.add_argument('--train', type =bool, default=True, help='train or test ')
36 |
37 |
38 | parser.add_argument('--datapath2015', default='/data6/wsgan/KITTI/KITTI2015/testing/', help='datapath')
39 | parser.add_argument('--datapath2012', default='/data6/wsgan/KITTI/KITTI2012/testing/', help='datapath')
40 | parser.add_argument('--datatype', default='2015', help='finetune dataset: 2012, 2015')
41 |
42 | args = parser.parse_args()
43 |
44 |
45 |
46 | if args.datatype == '2015':
47 | from dataloader import KITTI_submission_loader as DA
48 |
49 | test_left_img, test_right_img = DA.dataloader2015(args.datapath2015)
50 |
51 | elif args.datatype == '2012':
52 |
53 | from dataloader import KITTI_submission_loader as DA
54 | test_left_img, test_right_img = DA.dataloader2012(args.datapath2012)
55 |
56 | else:
57 |
58 |     raise AssertionError("unknown datatype, expected 2012 or 2015")
59 |
60 |
61 |
62 |
63 | if not os.path.isdir(args.save_path):
64 |     os.makedirs(args.save_path)
65 |
66 | log = logger.setup_logger(args.save_path + '/training.log')
67 | for key, value in sorted(vars(args).items()):
68 |     log.info(str(key) + ': ' + str(value))
69 |
70 | model = LWANet(args)
71 | if args.cuda:
72 |     model = nn.DataParallel(model)
73 |     model.cuda()
74 |
75 | # load the pretrained weights after the model has been built
76 | if args.pretrained:
77 |     if os.path.isfile(args.pretrained):
78 |         checkpoint = torch.load(args.pretrained)
79 |         model.load_state_dict(checkpoint['state_dict'], strict=False)
80 |         log.info('=> loaded pretrained model {}'.format(args.pretrained))
81 |     else:
82 |         log.info('=> no pretrained model found at {}'.format(args.pretrained))
83 |         log.info("=> Will start from scratch.")
84 | else:
85 |     log.info('Not Resume')
86 |
87 |
88 |
89 | def test(imgL,imgR):
90 |
91 | model.eval()
92 | if args.cuda:
93 | imgL = imgL.cuda()
94 | imgR = imgR.cuda()
95 |
96 | with torch.no_grad():
97 | disp, loss = model(imgL,imgR)
98 |
99 | disp = torch.squeeze(disp[-1])
100 | #print('disp size:', disp.shape)
101 | pred_disp = disp.data.cpu().numpy()
102 |
103 | return pred_disp
104 |
105 |
106 |
107 | def main():
108 | normal_mean_var = {'mean': [0.485, 0.456, 0.406],
109 | 'std': [0.229, 0.224, 0.225]}
110 | infer_transform = transforms.Compose([transforms.ToTensor(),
111 | transforms.Normalize(**normal_mean_var)])
112 |
113 | total_inference_time = 0
114 |
115 | for inx in range(len(test_left_img)):
116 |
117 | imgL_o = Image.open(test_left_img[inx]).convert('RGB')
118 | imgR_o = Image.open(test_right_img[inx]).convert('RGB')
119 |
120 |
121 | imgL = infer_transform(imgL_o)
122 | imgR = infer_transform(imgR_o)
123 |
124 |
125 | # pad to width and hight to 16 times
126 | if imgL.shape[1] % 16 != 0:
127 | times = imgL.shape[1]//16
128 | top_pad = (times+1)*16 -imgL.shape[1]
129 | else:
130 | top_pad = 0
131 |
132 | if imgL.shape[2] % 16 != 0:
133 | times = imgL.shape[2]//16
134 | right_pad = (times+1)*16-imgL.shape[2]
135 | else:
136 | right_pad = 0
137 |
138 | imgL = F.pad(imgL,(0,right_pad, top_pad,0)).unsqueeze(0)
139 | imgR = F.pad(imgR,(0,right_pad, top_pad,0)).unsqueeze(0)
140 |
141 | start_time = time.time()
142 | pred_disp = test(imgL,imgR)
143 |
144 | total_inference_time += time.time() - start_time
145 |
146 | if top_pad !=0 or right_pad != 0:
147 |             img = pred_disp[top_pad:, :pred_disp.shape[1] - right_pad]
148 | else:
149 | img = pred_disp
150 |
151 | img = (img*256).astype('uint16')
152 | img = Image.fromarray(img)
153 | print("inx:", inx)
154 |         img.save(args.save_path + '/' + test_left_img[inx].split('/')[-1])
155 |
156 |
157 | log.info("mean inference time: %.3fs " % (total_inference_time/len(test_left_img)))
158 |
159 | log.info("finish {} images inference".format(len(test_left_img)))
160 |
161 |
162 |
163 | if __name__ == '__main__':
164 | main()
165 |
166 |
--------------------------------------------------------------------------------
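main() above pads every input so that height and width become multiples of 16 (padding at the top and on the right, removed again after inference). A standalone sketch of just that step, using a dummy tensor at a typical KITTI 2015 resolution rather than a real image:

```python
import torch
import torch.nn.functional as F

img = torch.randn(3, 375, 1242)   # C x H x W, typical KITTI 2015 resolution

top_pad = (-img.shape[1]) % 16    # 375 -> pad 9 to reach 384
right_pad = (-img.shape[2]) % 16  # 1242 -> pad 6 to reach 1248

# Same F.pad argument order as in submission.py: (left, right, top, bottom).
padded = F.pad(img, (0, right_pad, top_pad, 0)).unsqueeze(0)
print(padded.shape)  # torch.Size([1, 3, 384, 1248])

# After inference the padding is cropped away again.
disp = padded[0, 0]  # stand-in for the predicted disparity map
disp = disp[top_pad:, : disp.shape[1] - right_pad]
print(disp.shape)    # torch.Size([375, 1242])
```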
/models/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class self_supervised_loss (nn.modules.Module):
6 | def __init__(self, n=1, SSIM_w=0.85, disp_gradient_w=1.0, lr_w=1.0):
7 | super(self_supervised_loss, self).__init__()
8 | self.SSIM_w = SSIM_w
9 | self.disp_gradient_w = disp_gradient_w
10 | self.lr_w = lr_w
11 | self.n = n
12 |
13 | def scale_pyramid(self, img):
14 | scaled_imgs = [img]
15 |
16 | return scaled_imgs
17 |
18 | def gradient_x(self, img):
19 | # Pad input to keep output size consistent
20 | img = F.pad(img, (0, 1, 0, 0), mode="replicate")
21 | gx = img[:, :, :, :-1] - img[:, :, :, 1:] # NCHW
22 | return gx
23 |
24 | def gradient_y(self, img):
25 | # Pad input to keep output size consistent
26 | img = F.pad(img, (0, 0, 0, 1), mode="replicate")
27 | gy = img[:, :, :-1, :] - img[:, :, 1:, :] # NCHW
28 | return gy
29 |
30 |
31 |
32 |     def apply_disparity(self, img, disp, cuda=True):
33 | '''
34 | img.shape = b, c, h, w
35 | disp.shape = b, h, w
36 | '''
37 | b, c, h, w = img.shape
38 | disp = disp.squeeze(1)
39 |
40 | if cuda == True:
41 | right_coor_x = (torch.arange(start=0, end=w, out=torch.cuda.FloatTensor())).repeat(b, h, 1)
42 | right_coor_y = (torch.arange(start=0, end=h, out=torch.cuda.FloatTensor())).repeat(b, w, 1).transpose(1, 2)
43 | else:
44 | right_coor_x = (torch.arange(start=0, end=w, out=torch.FloatTensor())).repeat(b, h, 1)
45 | right_coor_y = (torch.arange(start=0, end=h, out=torch.FloatTensor())).repeat(b, w, 1).transpose(1, 2)
46 | left_coor_x1 = right_coor_x + disp
47 | left_coor_norm1 = torch.stack((left_coor_x1 / (w - 1) * 2 - 1, right_coor_y / (h - 1) * 2 - 1), dim=1)
48 | ## backward warp
49 | warp_img = torch.nn.functional.grid_sample(img, left_coor_norm1.permute(0, 2, 3, 1))
50 |
51 | return warp_img
52 |
53 | def generate_image_left(self, img, disp):
54 | return self.apply_disparity(img, -disp)
55 |
56 | def generate_image_right(self, img, disp):
57 | return self.apply_disparity(img, disp)
58 |
59 | def SSIM(self, x, y):
60 | C1 = 0.01 ** 2
61 | C2 = 0.03 ** 2
62 |
63 | mu_x = nn.AvgPool2d(3, 1)(x)
64 | mu_y = nn.AvgPool2d(3, 1)(y)
65 | mu_x_mu_y = mu_x * mu_y
66 | mu_x_sq = mu_x.pow(2)
67 | mu_y_sq = mu_y.pow(2)
68 |
69 | sigma_x = nn.AvgPool2d(3, 1)(x * x) - mu_x_sq
70 | sigma_y = nn.AvgPool2d(3, 1)(y * y) - mu_y_sq
71 | sigma_xy = nn.AvgPool2d(3, 1)(x * y) - mu_x_mu_y
72 |
73 | SSIM_n = (2 * mu_x_mu_y + C1) * (2 * sigma_xy + C2)
74 | SSIM_d = (mu_x_sq + mu_y_sq + C1) * (sigma_x + sigma_y + C2)
75 | SSIM = SSIM_n / SSIM_d
76 |
77 | return torch.clamp((1 - SSIM) / 2, 0, 1)
78 |
79 | def SSIM_WEIGHT(self, x, y):
80 | C1 = 0.01 ** 2
81 | C2 = 0.03 ** 2
82 |
83 | mu_x = nn.AvgPool2d(3, 1)(x)
84 | mu_y = nn.AvgPool2d(3, 1)(y)
85 | mu_x_mu_y = mu_x * mu_y
86 | mu_x_sq = mu_x.pow(2)
87 | mu_y_sq = mu_y.pow(2)
88 |
89 | sigma_x = nn.AvgPool2d(3, 1)(x * x) - mu_x_sq
90 | sigma_y = nn.AvgPool2d(3, 1)(y * y) - mu_y_sq
91 | sigma_xy = nn.AvgPool2d(3, 1)(x * y) - mu_x_mu_y
92 |
93 | SSIM_n = (2 * mu_x_mu_y + C1) * (2 * sigma_xy + C2)
94 | SSIM_d = (mu_x_sq + mu_y_sq + C1) * (sigma_x + sigma_y + C2)
95 | SSIM = SSIM_n / SSIM_d
96 |
97 | return torch.clamp((SSIM) / 2, 0, 1)
98 |
99 | def disp_smoothness(self, disp, pyramid):
100 | disp_gradients_x = [self.gradient_x(d) for d in disp]
101 | disp_gradients_y = [self.gradient_y(d) for d in disp]
102 |
103 | image_gradients_x = [self.gradient_x(img) for img in pyramid]
104 | image_gradients_y = [self.gradient_y(img) for img in pyramid]
105 |
106 | weights_x = [torch.exp(-torch.mean(torch.abs(g), 1,
107 | keepdim=True)) for g in image_gradients_x]
108 | weights_y = [torch.exp(-torch.mean(torch.abs(g), 1,
109 | keepdim=True)) for g in image_gradients_y]
110 |
111 | smoothness_x = [disp_gradients_x[i] * weights_x[i]
112 | for i in range(self.n)]
113 | smoothness_y = [disp_gradients_y[i] * weights_y[i]
114 | for i in range(self.n)]
115 |
116 | return [torch.abs(smoothness_x[i]) + torch.abs(smoothness_y[i])
117 | for i in range(self.n)]
118 |
119 |
120 | def reconstruction_image_first_order_gradient(self, left_est, left_pyramid):
121 |
122 | RI_x = [self.gradient_x(d) for d in left_est]
123 | RI_y = [self.gradient_y(d) for d in left_est]
124 |
125 | OI_x = [self.gradient_x(d) for d in left_pyramid]
126 | OI_y = [self.gradient_y(d) for d in left_pyramid]
127 |
128 |         first_order_loss = [torch.mean(torch.abs(RI_x[i] - OI_x[i])) + torch.mean(torch.abs(RI_y[i] - OI_y[i]))
129 |                             for i in range(self.n)]
130 |
131 |         return first_order_loss
132 |
133 | def forward(self, input, target):
134 | """
135 | Args:
136 |             input: predicted left disparity, B x 1 x H x W
137 |             target: [left, right] input images
138 |
139 | Return:
140 | (float): The loss
141 | """
142 |
143 | left, right = target
144 |
145 |
146 | disp_left_est = [input[:, 0, :, :].unsqueeze(1) ]
147 |
148 | left_pyramid = self.scale_pyramid(left)
149 | right_pyramid = self.scale_pyramid(right)
150 |
151 | # Generate images
152 | left_est = [self.generate_image_left(right_pyramid[0],
153 | disp_left_est[0]) ]
154 |
155 |
156 | # Disparities smoothness
157 | disp_left_smoothness = self.disp_smoothness(disp_left_est, left_pyramid)
158 | l1_left = [torch.mean(torch.abs(left_est[0] - left_pyramid[0]))]
159 | ssim_left = [torch.mean(self.SSIM(left_est[0], left_pyramid[0])) ]
160 | image_loss_left = [self.SSIM_w * ssim_left[0] + (1 - self.SSIM_w) * (l1_left[0] )]
161 |
162 | image_loss = sum(image_loss_left)
163 |
164 | # Disparities smoothness
165 | disp_left_loss = [torch.mean(torch.abs(disp_left_smoothness[0])) ]
166 | disp_gradient_loss = sum(disp_left_loss)
167 |
168 | loss = image_loss + self.disp_gradient_w * disp_gradient_loss
169 |
170 | return loss
171 |
--------------------------------------------------------------------------------
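The forward pass above mixes an SSIM term and an L1 term on the left image reconstructed from the right view. A CPU-only sketch of that photometric mix on dummy tensors (the warp itself is skipped here, since apply_disparity hard-codes CUDA tensors); the shapes and noise level are illustrative only:

```python
import torch
from models.loss import self_supervised_loss

loss_fn = self_supervised_loss(n=1, SSIM_w=0.85, disp_gradient_w=0.1)

left_est = torch.rand(2, 3, 64, 128)               # stand-in for the reconstructed left image
left = left_est + 0.05 * torch.randn_like(left_est)  # stand-in for the real left image

ssim_term = loss_fn.SSIM(left_est, left).mean()
l1_term = (left_est - left).abs().mean()
photometric = 0.85 * ssim_term + 0.15 * l1_term      # same weighting as image_loss_left
print(float(ssim_term), float(l1_term), float(photometric))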
/models/Aggregation_submodules.py:
--------------------------------------------------------------------------------
1 | #coding=utf-8
2 | from __future__ import print_function
3 | import torch.nn as nn
4 | import math
5 | import torch
6 | import torch.nn.functional as F
7 |
8 |
9 | def activation_function(types = "ELU"): # ELU or Relu
10 |
11 |
12 | if types == "ELU":
13 |
14 | return nn.Sequential(nn.ELU(inplace=True))
15 |
16 | elif types == "Mish":
17 |
18 |         return nn.Sequential(Mish())
19 |
20 | elif types == "Relu":
21 |
22 | return nn.Sequential(nn.ReLU(inplace=True))
23 |
24 | else:
25 |
26 |         raise AssertionError("please define the activation function type")
27 |
28 |
29 |
30 |
31 | class Mish(nn.Module):
32 | '''
33 | Applies the mish function element-wise:
34 | mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
35 | Shape:
36 | - Input: (N, *) where * means, any number of additional
37 | dimensions
38 | - Output: (N, *), same shape as the input
39 |     Examples:
40 |         >>> output = Mish()(torch.randn(4))
41 |     '''
42 | def __init__(self):
43 | '''
44 | Init method.
45 | '''
46 | super().__init__()
47 |
48 | def forward(self, input):
49 | '''
50 | Forward pass of the function.
51 | '''
52 | return input * torch.tanh(F.softplus(input))
53 |
54 |
55 |
56 | # cost aggregation submodule
57 |
58 | def conv_3d(in_planes, out_planes, kernel_size, stride, pad, conv_3d_types="3D"):
59 |
60 |
61 | if conv_3d_types == "3D":
62 |
63 | return nn.Sequential(
64 | nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, padding=pad, stride=stride, bias=False)
65 | )
66 |
67 |
68 | elif conv_3d_types == "P3D": # 3*3*3 to 1*3*3 + 3*1*1
69 |
70 | return nn.Sequential(
71 |
72 | nn.Conv3d(in_planes, out_planes, kernel_size=(1, 3, 3), stride=stride, padding=(0, 1, 1), bias=False),
73 | nn.ReLU(inplace=True),
74 | nn.Conv3d(out_planes, out_planes, kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0), bias=False),
75 |
76 | )
77 |
78 |
79 | else:
80 |
81 |         raise AssertionError("please define conv_3d_types")
82 |
83 |
84 |
85 |
86 | def convbn_3d(in_planes, out_planes, kernel_size, stride, pad, conv_3d_types="3D"):
87 |
88 |
89 | if conv_3d_types == "3D":
90 |
91 | return nn.Sequential(
92 | nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, padding=pad, stride=stride, bias=False),
93 | nn.BatchNorm3d(out_planes))
94 |
95 |
96 | elif conv_3d_types == "P3D": # 3*3*3 to 1*3*3 + 3*1*1
97 |
98 | return nn.Sequential(
99 |
100 | nn.Conv3d(in_planes, out_planes, kernel_size=(1, 3, 3), stride=stride, padding=(0, 1, 1), bias=False),
101 | nn.ReLU(inplace=True),
102 | nn.Conv3d(out_planes, out_planes, kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0), bias=False),
103 |
104 | nn.BatchNorm3d(out_planes))
105 |
106 |
107 | else:
108 |
109 |         raise AssertionError("please define conv_3d_types")
110 |
111 |
112 |
113 | def convTranspose3d(in_planes, out_planes, kernel_size, stride, padding=1, conv_3d_types="P3D"):
114 |
115 | if conv_3d_types == '3D':
116 | return nn.Sequential(
117 | nn.ConvTranspose3d(in_planes, out_planes, kernel_size, padding = padding, output_padding=1, stride=stride, bias=False),
118 | nn.BatchNorm3d(out_planes))
119 |
120 |
121 | elif conv_3d_types == "P3D":
122 |
123 | return nn.Sequential(
124 | nn.ConvTranspose3d(in_planes, out_planes, kernel_size, padding=padding, output_padding=1, stride=stride,
125 | bias=False),
126 | nn.BatchNorm3d(out_planes))
127 |
128 |
129 | else:
130 |         raise AssertionError("please define conv_3d_types")
131 |
132 |
133 |
134 |
135 | class LWANet_Aggregation(nn.Module): # base on PSMNet basic
136 | def __init__(self, input_planes=8, planes=16, maxdisp=192, conv_3d_types1 = "P3D", conv_3d_types2 = "P3D", activation_types2 = "ELU"):
137 | super(LWANet_Aggregation, self).__init__()
138 | self.maxdisp = maxdisp
139 |
140 | self.pre_3D = nn.Sequential(
141 | convbn_3d(input_planes, planes, 3, 1, 1, conv_3d_types = conv_3d_types1),
142 | activation_function(types = activation_types2),
143 | convbn_3d(planes, planes, 3, 2, 1, conv_3d_types = conv_3d_types1),
144 | activation_function(types = activation_types2)
145 | )
146 |
147 | self.middle_3D = nn.Sequential(
148 |
149 | convbn_3d(planes, planes*2, 3, 1, 1, conv_3d_types = conv_3d_types2),
150 | activation_function(types = activation_types2),
151 | convbn_3d(planes*2, planes*4, 3, 1, 1, conv_3d_types = conv_3d_types2),
152 | activation_function(types = activation_types2),
153 | convbn_3d(planes * 4, planes * 4, 3, 1, 1, conv_3d_types=conv_3d_types2),
154 | activation_function(types=activation_types2),
155 | convbn_3d(planes * 4, planes * 2, 3, 1, 1, conv_3d_types=conv_3d_types2),
156 | activation_function(types=activation_types2),
157 | convTranspose3d(planes * 2, planes * 2, kernel_size=3, stride=2, conv_3d_types=conv_3d_types2),
158 | activation_function(types=activation_types2)
159 | )
160 |
161 | self.post_3D = nn.Sequential(
162 | convbn_3d(planes*2, planes, 3, 1, 1, conv_3d_types = conv_3d_types1),
163 | activation_function(types = activation_types2),
164 | conv_3d(planes, 1, kernel_size=3, pad=1, stride=1, conv_3d_types = conv_3d_types1)
165 | )
166 |
167 |
168 | for m in self.modules():
169 | if isinstance(m, nn.Conv2d):
170 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
171 | m.weight.data.normal_(0, math.sqrt(2. / n))
172 | elif isinstance(m, nn.Conv3d):
173 | n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels
174 | m.weight.data.normal_(0, math.sqrt(2. / n))
175 | elif isinstance(m, nn.BatchNorm2d):
176 | m.weight.data.fill_(1)
177 | m.bias.data.zero_()
178 | elif isinstance(m, nn.BatchNorm3d):
179 | m.weight.data.fill_(1)
180 | m.bias.data.zero_()
181 | elif isinstance(m, nn.Linear):
182 | m.bias.data.zero_()
183 |
184 | def forward(self, cost):
185 |
186 | cost = self.pre_3D(cost)
187 | cost = self.middle_3D(cost)
188 | cost = self.post_3D(cost)
189 |
190 |
191 | return cost
192 |
193 |
194 |
195 |
196 |
--------------------------------------------------------------------------------
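A quick shape check of the aggregation block above on a dummy difference-style cost volume (CPU, P3D convolutions). The sizes are illustrative, not taken from a training run: with maxdisp 192 and 1/8-resolution features, the disparity dimension is 24.

```python
import torch
from models.Aggregation_submodules import LWANet_Aggregation

agg = LWANet_Aggregation(input_planes=16, planes=12,
                         conv_3d_types1="P3D", conv_3d_types2="P3D")

# B x C x D x H x W cost volume at 1/8 resolution.
cost = torch.randn(1, 16, 24, 32, 64)
out = agg(cost)
print(out.shape)  # torch.Size([1, 1, 24, 32, 64]): one matching score per disparity
```

The stride-2 convolution in pre_3D halves D, H and W, and the transposed convolution in middle_3D doubles them again, so the output keeps the input resolution with a single channel.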
/models/LWANet.py:
--------------------------------------------------------------------------------
1 | #coding=utf-8
2 | from __future__ import print_function
3 | import torch
4 | import torch.nn as nn
5 | import torch.utils.data
6 | from torch.autograd import Variable
7 | import torch.nn.functional as F
8 | import math
9 | from .cspn import Affinity_Propagate
10 | from .feature_extraction import F1, F2, F3, F2_UP, F3_UP , F1_UP
11 | from .Aggregation_submodules import LWANet_Aggregation
12 | from .cost import _build_cost_volume
13 | from .loss import self_supervised_loss
14 |
15 |
16 | class LWANet(nn.Module):
17 | def __init__(self, args):
18 | super(LWANet, self).__init__()
19 |
20 | #self.init_channels = args.init_channels
21 | self.maxdisplist = args.maxdisplist
22 | self.with_cspn = args.with_cspn
23 |         self.model_types = args.model_types  # LWANet model type: 3D or P3D
24 | self.conv_3d_types1 = args.conv_3d_types1
25 | self.conv_3d_types2 = args.conv_3d_types2
26 | self.cost_volume = args.cost_volume
27 | self.maxdisp = args.maxdisp
28 |
29 |
30 | self.F1 = F1()
31 | self.F2 = F2()
32 | self.F3 = F3()
33 |
34 | self.F1_CSPN = F1()
35 | self.F2_CSPN = F2()
36 | self.F3_CSPN = F3()
37 |
38 | if self.cost_volume =="Distance_based":
39 |
40 | self.volume_postprocess = LWANet_Aggregation( input_planes=1, planes=8,
41 | conv_3d_types1 = self.conv_3d_types1,
42 | conv_3d_types2 = self.conv_3d_types2)
43 |
44 | elif self.cost_volume =="Difference":
45 | self.volume_postprocess = LWANet_Aggregation(input_planes=16, planes=12,
46 | conv_3d_types1=self.conv_3d_types1,
47 | conv_3d_types2=self.conv_3d_types2)
48 |
49 | elif self.cost_volume =="Concat":
50 | self.volume_postprocess = LWANet_Aggregation(input_planes=32, planes=12,
51 | conv_3d_types1=self.conv_3d_types1,
52 | conv_3d_types2=self.conv_3d_types2)
53 |
54 | if self.with_cspn:
55 |
56 | self.F2_UP = F2_UP()
57 | self.F3_UP = F3_UP()
58 | self.F1_UP = F1_UP()
59 |
60 | cspn_config_default = {'step':4, 'kernel': 3, 'norm_type': '8sum'}
61 | self.post_process_layer = [self._make_post_process_layer(cspn_config_default)]
62 | self.post_process_layer = nn.ModuleList(self.post_process_layer)
63 |
64 | self.self_supervised_loss = self_supervised_loss(n=1, SSIM_w=0.85, disp_gradient_w=0.1, lr_w=1)
65 |
66 | for m in self.modules():
67 | if isinstance(m, nn.Conv2d):
68 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
69 | m.weight.data.normal_(0, math.sqrt(2. / n))
70 | elif isinstance(m, nn.Conv3d):
71 | n = m.kernel_size[0] * m.kernel_size[1]*m.kernel_size[2] * m.out_channels
72 | m.weight.data.normal_(0, math.sqrt(2. / n))
73 | elif isinstance(m, nn.BatchNorm2d):
74 | m.weight.data.fill_(1)
75 | m.bias.data.zero_()
76 | elif isinstance(m, nn.BatchNorm3d):
77 | m.weight.data.fill_(1)
78 | m.bias.data.zero_()
79 | elif isinstance(m, nn.Linear):
80 | m.bias.data.zero_()
81 |
82 |
83 | def _make_post_process_layer(self, cspn_config=None):
84 | return Affinity_Propagate(cspn_config['step'],
85 | cspn_config['kernel'],
86 | norm_type=cspn_config['norm_type'])
87 |
88 | def forward(self, left, right):
89 |
90 |
91 | img_size = left.size()
92 |
93 | feats_l_F1 = self.F1(left)
94 |
95 | feats_l_F2 = self.F2(feats_l_F1)
96 |
97 | feats_l_F3 = self.F3(feats_l_F2)
98 |
99 |
100 | feats_l = feats_l_F3
101 |
102 |
103 | feats_r_F1 = self.F1(right)
104 |
105 | feats_r_F2 = self.F2(feats_r_F1)
106 |
107 | feats_r_F3 = self.F3(feats_r_F2)
108 |
109 | feats_r = feats_r_F3
110 |
111 |
112 | pred = []
113 |
114 | cost = _build_cost_volume(self.cost_volume, feats_l, feats_r, self.maxdisp)
115 |
116 | cost = self.volume_postprocess(cost).squeeze(1)
117 |
118 |
119 | pred_low_res_left = disparityregression2(0, self.maxdisplist[0])(F.softmax(-cost, dim=1))
120 |
121 | pred_low_res = pred_low_res_left * img_size[2] / pred_low_res_left.size(2)
122 |
123 | disp_up = F.upsample(pred_low_res, (img_size[2], img_size[3]), mode='bilinear')
124 |
125 | pred.append(disp_up)
126 |
127 |
128 | if self.with_cspn:
129 |
130 | feats_l_F1_CSPN = self.F1_CSPN(left)
131 |
132 | feats_l_F2_CSPN = self.F2_CSPN(feats_l_F1_CSPN)
133 |
134 | feats_l_F3_CSPN = self.F3_CSPN(feats_l_F2_CSPN)
135 |
136 |
137 |
138 | F3_UP = torch.cat((self.F3_UP(feats_l_F3_CSPN), feats_l_F2_CSPN), 1)
139 |
140 | F2_UP = torch.cat((self.F2_UP(F3_UP), feats_l_F1_CSPN), 1)
141 |
142 | F1_UP = self.F1_UP(F2_UP)
143 |
144 | x = self.post_process_layer[0](F1_UP, disp_up)
145 |
146 | pred.append(x)
147 |
148 | loss = []
149 |
150 |
151 |         if self.training:
152 | for outputs in pred:
153 | loss.append(self.self_supervised_loss(outputs, [left, right]))
154 |
155 | else:
156 | loss = [0]
157 |
158 |         pred = [torch.squeeze(p, 1) for p in pred]
159 |
160 | return pred, loss
161 |
162 |
163 |
164 | class disparityregression2(nn.Module):
165 | def __init__(self, start, end, stride=1):
166 | super(disparityregression2, self).__init__()
167 | self.disp = Variable(torch.arange(start*stride, end*stride, stride).view(1, -1, 1, 1).cuda(), requires_grad=False)
168 |
169 | def forward(self, x):
170 | disp = self.disp.repeat(x.size()[0], 1, x.size()[2], x.size()[3])
171 | disp = disp.float()
172 |
173 | out = torch.sum(x*disp, 1, keepdim=True)
174 | return out
175 |
176 |
177 |
178 | class L1Loss(object):
179 | def __call__(self, input, target):
180 | return torch.abs(input - target).mean()
181 |
182 |
183 |
184 | def apply_disparity(img, disp):
185 |
186 | batch_size, _, height, width = img.size()
187 |
188 | # Original coordinates of pixels
189 | x_base = torch.linspace(0, 1, width).repeat(batch_size,
190 | height, 1).type_as(img)
191 | y_base = torch.linspace(0, 1, height).repeat(batch_size,
192 | width, 1).transpose(1, 2).type_as(img)
193 |
194 | # Apply shift in X direction
195 | x_shifts = disp[:, 0, :, :] # Disparity is passed in NCHW format with 1 channel
196 | flow_field = torch.stack((x_base + x_shifts, y_base), dim=3)
197 | # In grid_sample coordinates are assumed to be between -1 and 1
198 | output = F.grid_sample(img, 2*flow_field - 1, mode='bilinear',
199 | padding_mode='zeros')
200 |
201 | return output
202 |
203 |
204 |
205 | def generate_image_left( img, disp):
206 | return apply_disparity(img, -disp)
207 |
208 |
209 |
--------------------------------------------------------------------------------
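disparityregression2 above implements the usual soft-argmin: a softmax over the negated cost gives a probability per disparity, and the expectation over the disparity indices is the prediction. A CPU sketch of the same computation on a dummy cost volume (the class itself hard-codes .cuda()):

```python
import torch
import torch.nn.functional as F

maxdisp = 24                              # matches maxdisplist[0]
cost = torch.randn(1, maxdisp, 32, 64)    # B x D x H x W aggregated cost

prob = F.softmax(-cost, dim=1)            # low cost -> high probability
disp_values = torch.arange(maxdisp).float().view(1, maxdisp, 1, 1)
pred = torch.sum(prob * disp_values, dim=1, keepdim=True)
print(pred.shape, float(pred.min()), float(pred.max()))  # B x 1 x H x W, values in [0, 23]
```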
/utils/flops_hook.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn.modules.conv import _ConvNd
4 | multiply_adds = 1
5 |
6 |
7 | def count_convNd(m: _ConvNd, x: (torch.Tensor,), y: torch.Tensor):
8 | x = x[0]
9 |
10 | kernel_ops = m.weight.size()[2:].numel() # Kw x Kh
11 |
12 |     # N x Cout x H x W x (Cin/groups x Kw x Kh); the bias term is ignored here
13 | total_ops = y.nelement() * (
14 | m.in_channels // m.groups * kernel_ops)
15 |
16 | m.total_ops += torch.Tensor([int(total_ops)])
17 |
18 |
19 |
20 |
21 | def count_bn(m, x, y):
22 | x = x[0]
23 |
24 | nelements = x.numel()
25 |     # subtract, divide, gamma, beta
26 |     # (profile() switches the model to eval mode; count the same ops either
27 |     #  way so that total_ops is never left undefined)
28 |     total_ops = 2 * nelements
29 |     m.total_ops += torch.Tensor([int(total_ops)])
30 |
31 |
32 |
33 |
34 | def count_relu(m, x, y):
35 | x = x[0]
36 |
37 | nelements = x.numel()
38 |
39 | m.total_ops += torch.Tensor([int(nelements)])
40 |
41 |
42 | def count_softmax(m, x, y):
43 | x = x[0]
44 |
45 | batch_size, nfeatures = x.size()
46 |
47 | total_exp = nfeatures
48 | total_add = nfeatures - 1
49 | total_div = nfeatures
50 | total_ops = batch_size * (total_exp + total_add + total_div)
51 |
52 | m.total_ops += torch.Tensor([int(total_ops)])
53 |
54 |
55 | def count_avgpool(m, x, y):
56 | # total_add = torch.prod(torch.Tensor([m.kernel_size]))
57 | # total_div = 1
58 | # kernel_ops = total_add + total_div
59 | kernel_ops = 1
60 | num_elements = y.numel()
61 | total_ops = kernel_ops * num_elements
62 |
63 | m.total_ops += torch.Tensor([int(total_ops)])
64 |
65 |
66 | def count_adap_avgpool(m, x, y):
67 | kernel = torch.Tensor(
68 | [*(x[0].shape[2:])]) // torch.Tensor(list((m.output_size,))).squeeze()
69 | total_add = torch.prod(kernel)
70 | total_div = 1
71 | kernel_ops = total_add + total_div
72 | num_elements = y.numel()
73 | total_ops = kernel_ops * num_elements
74 |
75 | m.total_ops += torch.Tensor([int(total_ops)])
76 |
77 |
78 | # TODO: verify the accuracy
79 | def count_upsample(m, x, y):
80 | if m.mode not in ("nearest", "linear", "bilinear", "bicubic",): # "trilinear"
81 |         print(
82 |             "count_upsample: mode %s is not implemented yet, treating it as a zero op" % m.mode)
83 | return zero_ops(m, x, y)
84 |
85 | if m.mode == "nearest":
86 | return zero_ops(m, x, y)
87 |
88 | x = x[0]
89 | if m.mode == "linear":
90 | total_ops = y.nelement() * 5 # 2 muls + 3 add
91 | elif m.mode == "bilinear":
92 | # https://en.wikipedia.org/wiki/Bilinear_interpolation
93 | total_ops = y.nelement() * 11 # 6 muls + 5 adds
94 | elif m.mode == "bicubic":
95 | # https://en.wikipedia.org/wiki/Bicubic_interpolation
96 | # Product matrix [4x4] x [4x4] x [4x4]
97 | ops_solve_A = 224 # 128 muls + 96 adds
98 | ops_solve_p = 35 # 16 muls + 12 adds + 4 muls + 3 adds
99 | total_ops = y.nelement() * (ops_solve_A + ops_solve_p)
100 | elif m.mode == "trilinear":
101 | # https://en.wikipedia.org/wiki/Trilinear_interpolation
102 | # can viewed as 2 bilinear + 1 linear
103 | total_ops = y.nelement() * (13 * 2 + 5)
104 |
105 | m.total_ops += torch.Tensor([int(total_ops)])
106 |
107 |
108 | def count_linear(m, x, y):
109 | # per output element
110 | total_mul = m.in_features
111 | total_add = m.in_features - 1
112 | total_add += 1 if m.bias is not None else 0
113 | num_elements = y.numel()
114 | total_ops = (total_mul + total_add) * num_elements
115 |
116 | m.total_ops += torch.Tensor([int(total_ops)])
117 |
118 |
119 | def zero_ops(m, x, y):
120 | m.total_ops += torch.Tensor([int(0)])
121 |
122 |
123 |
124 | register_hooks = {
125 | nn.Conv1d: count_convNd,
126 | nn.Conv2d: count_convNd,
127 | nn.Conv3d: count_convNd,
128 | nn.ConvTranspose1d: count_convNd,
129 | nn.ConvTranspose2d: count_convNd,
130 | nn.ConvTranspose3d: count_convNd,
131 |
132 | nn.BatchNorm1d: count_bn,
133 | nn.BatchNorm2d: count_bn,
134 | nn.BatchNorm3d: count_bn,
135 |
136 | nn.ReLU: zero_ops,
137 | nn.ReLU6: zero_ops,
138 | nn.LeakyReLU: count_relu,
139 |
140 | nn.MaxPool1d: zero_ops,
141 | nn.MaxPool2d: zero_ops,
142 | nn.MaxPool3d: zero_ops,
143 | nn.AdaptiveMaxPool1d: zero_ops,
144 | nn.AdaptiveMaxPool2d: zero_ops,
145 | nn.AdaptiveMaxPool3d: zero_ops,
146 |
147 | nn.AvgPool1d: count_avgpool,
148 | nn.AvgPool2d: count_avgpool,
149 | nn.AvgPool3d: count_avgpool,
150 | nn.AdaptiveAvgPool1d: count_adap_avgpool,
151 | nn.AdaptiveAvgPool2d: count_adap_avgpool,
152 | nn.AdaptiveAvgPool3d: count_adap_avgpool,
153 |
154 | nn.Linear: count_linear,
155 | nn.Dropout: zero_ops,
156 |
157 | nn.Upsample: count_upsample,
158 | nn.UpsamplingBilinear2d: count_upsample,
159 | nn.UpsamplingNearest2d: count_upsample
160 | }
161 |
162 |
163 |
164 |
165 |
166 | def profile(model, inputs, custom_ops=None, verbose=True):
167 | handler_collection = []
168 | if custom_ops is None:
169 | custom_ops = {}
170 |
171 | def add_hooks(m):
172 | if len(list(m.children())) > 0:
173 | return
174 |
175 | # if hasattr(m, "total_ops") or hasattr(m, "total_params"):
176 | # raise Warning("Either .total_ops or .total_params is already defined in %s.\n"
177 | # "Be careful, it might change your code's behavior." % str(m))
178 |
179 | m.register_buffer('total_ops', torch.zeros(1))
180 | m.register_buffer('total_params', torch.zeros(1))
181 |
182 | for p in m.parameters():
183 | m.total_params += torch.Tensor([p.numel()])
184 |
185 | m_type = type(m)
186 | fn = None
187 | if m_type in custom_ops: # if defined both op maps, use custom_ops to overwrite.
188 | fn = custom_ops[m_type]
189 | elif m_type in register_hooks:
190 | fn = register_hooks[m_type]
191 |
192 | if fn is None:
193 | if verbose:
194 | print("THOP has not implemented counting method for", m)
195 | else:
196 | if verbose:
197 | print("Register FLOP counter for module %s" % str(m))
198 | handler = m.register_forward_hook(fn)
199 | handler_collection.append(handler)
200 |
201 | # original_device = model.parameters().__next__().device
202 | training = model.training
203 |
204 | model.eval()
205 | model.apply(add_hooks)
206 |
207 | with torch.no_grad():
208 | model(*inputs)
209 |
210 | total_ops = 0
211 | total_params = 0
212 | _temp = []
213 | for m in model.modules():
214 | if len(list(m.children())) > 0: # skip for non-leaf module
215 | continue
216 | total_ops += m.total_ops
217 | total_params += m.total_params
218 | _temp.append(m.total_ops.item())
219 |
220 | total_ops = total_ops.item()
221 | total_params = total_params.item()
222 |
223 | # reset model to original status
224 | model.train(training)
225 | for handler in handler_collection:
226 | handler.remove()
227 |
228 | return clever_format(total_ops), clever_format(total_params)
229 |
230 |
231 | def clever_format(num):
232 | # if num > 1e12:
233 | # return "%.2f" % (num / 1e12) + "T"
234 | if num > 1e9:
235 | return "%.2f" % (num / 1e9) + "G"
236 | if num > 1e6:
237 | return "%.2f" % (num / 1e6) + "M"
238 |     if num > 1e3:
239 |         return "%.2f" % (num / 1e3) + "K"
240 |     return "%.2f" % num
--------------------------------------------------------------------------------
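A minimal sketch of running profile() above on a small throwaway network (not one of the repository models), assuming the project root is on the Python path. Leaf modules without a registered counter are simply reported (or skipped silently with verbose=False):

```python
import torch
import torch.nn as nn
from utils.flops_hook import profile

net = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(inplace=True),
    nn.Conv2d(8, 1, 3, padding=1),
)

flops, params = profile(net, inputs=(torch.randn(1, 3, 64, 64),), verbose=False)
print(flops, params)  # formatted strings, roughly "1.25M" ops and "313.00" parameters here
```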
/models/cost.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 |
5 |
6 | def _build_volume_2d_anynet(feat_l, feat_r, maxdisp, stride=1):
7 |
8 | assert maxdisp % stride == 0 # Assume maxdisp is multiple of stride
9 | cost = torch.zeros((feat_l.size()[0], maxdisp//stride, feat_l.size()[2], feat_l.size()[3]), device='cuda')
10 | for i in range(0, maxdisp, stride):
11 | cost[:, i// stride, :, :i] = feat_l[:, :, :, :i].abs().sum(1)
12 |
13 | if i > 0:
14 | cost[:, i // stride, :, i:] = torch.norm(feat_l[:, :, :, i:] - feat_r[:, :, :, :-i], 1, 1)
15 | else:
16 | cost[:, i // stride, :, i:] = torch.norm(feat_l[:, :, :, :] - feat_r[:, :, :, :], 1, 1)
17 |
18 | return cost.contiguous()
19 |
20 |
21 | def _build_volume_2d3_anynet( feat_l, feat_r, disp, maxdisp=3, stride=1):
22 | size = feat_l.size()
23 | batch_disp = disp[:, None, :, :, :].repeat(1, maxdisp * 2 - 1, 1, 1, 1).view(-1, 1, size[-2], size[-1])
24 | batch_shift = torch.arange(-maxdisp + 1, maxdisp, device='cuda').repeat(size[0])[:, None, None, None] * stride
25 | batch_disp = batch_disp - batch_shift.float()
26 | batch_feat_l = feat_l[:, None, :, :, :].repeat(1, maxdisp * 2 - 1, 1, 1, 1).view(-1, size[-3], size[-2], size[-1])
27 | batch_feat_r = feat_r[:, None, :, :, :].repeat(1, maxdisp * 2 - 1, 1, 1, 1).view(-1, size[-3], size[-2], size[-1])
28 |
29 | cost = torch.norm(batch_feat_l - warp(batch_feat_r, batch_disp), 1, 1)
30 |
31 | cost = cost.view(size[0], -1, size[2], size[3])
32 |
33 | return cost.contiguous()
34 |
35 |
36 | def _build_volume_2d_psmnet( refimg_fea, targetimg_fea, maxdisp):
37 | cost = Variable(
38 | torch.FloatTensor(refimg_fea.size()[0], refimg_fea.size()[1] * 2, maxdisp , refimg_fea.size()[2],
39 | refimg_fea.size()[3]).zero_()).cuda()
40 |
41 | for i in range(maxdisp):
42 | if i > 0:
43 |
44 | cost[:, :refimg_fea.size()[1], i, :, i:] = refimg_fea[:, :, :, i:]
45 | cost[:, refimg_fea.size()[1]:, i, :, i:] = targetimg_fea[:, :, :, :-i]
46 | else:
47 | cost[:, :refimg_fea.size()[1], i, :, :] = refimg_fea
48 | cost[:, refimg_fea.size()[1]:, i, :, :] = targetimg_fea
49 |
50 | return cost.contiguous()
51 |
52 |
53 |
54 |
55 | def _build_volume_2d3_psmnet(feat_l, feat_r, disp, maxdisp=3, stride=1):
56 | size = feat_l.size()
57 |
58 | batch_disp = disp[:, None, :, :, :].repeat(1, maxdisp * 2 - 1, 1, 1, 1).view(-1, 1, size[-2], size[-1])
59 | batch_shift = torch.arange(-maxdisp + 1, maxdisp, device='cuda').repeat(size[0])[:, None, None, None] * stride
60 | batch_disp = batch_disp - batch_shift.float()
61 | batch_feat_l = feat_l[:, None, :, :, :].repeat(1, maxdisp * 2 - 1, 1, 1, 1).view(-1, size[-3], size[-2], size[-1])
62 | batch_feat_r = feat_r[:, None, :, :, :].repeat(1, maxdisp * 2 - 1, 1, 1, 1).view(-1, size[-3], size[-2],
63 | size[-1])
64 | #cost = batch_feat_l - warp(batch_feat_r, batch_disp)
65 | cost = torch.cat((batch_feat_l , warp(batch_feat_r, batch_disp)), 1).contiguous()
66 |
67 | cost = cost.view(size[0], size[1]*2, -1, size[2], size[3])
68 | # print("cost size", cost.shape)
69 |
70 | return cost.contiguous()
71 |
72 |
73 |
74 |
75 | def _build_volume_2d_aanet(refimg_fea, targetimg_fea, maxdisp):
76 |
77 |
78 | b, c, h, w = refimg_fea.size()
79 | cost_volume = refimg_fea.new_zeros(b, maxdisp, h, w)
80 |
81 | for i in range(maxdisp):
82 | if i > 0:
83 | cost_volume[:, i, :, i:] = (refimg_fea[:, :, :, i:] *
84 | targetimg_fea[:, :, :, :-i]).mean(dim=1)
85 | else:
86 | cost_volume[:, i, :, :] = (refimg_fea * targetimg_fea).mean(dim=1)
87 |
88 | return cost_volume.contiguous()
89 |
90 |
91 |
92 |
93 |
94 | def _build_volume_2d_difference(feat_l, feat_r, maxdisp):
95 |
96 | b, c, h, w = feat_l.size()
97 |
98 |
99 | cost_volume = feat_l.new_zeros(b, c, maxdisp, h, w)
100 |
101 | for i in range(maxdisp):
102 | if i > 0:
103 | cost_volume[:, :, i, :, i:] = feat_l[:, :, :, i:] - feat_r[:, :, :, :-i]
104 | else:
105 | cost_volume[:, :, i, :, :] = feat_l - feat_r
106 |
107 | return cost_volume
108 |
109 |
110 |
111 |
112 |
113 | def _build_volume_2d3_difference( feat_l, feat_r, disp, maxdisp=3, stride=1):
114 | size = feat_l.size()
115 | batch_disp = disp[:, None, :, :, :].repeat(1, maxdisp * 2 - 1, 1, 1, 1).view(-1, 1, size[-2], size[-1])
116 | batch_shift = torch.arange(-maxdisp + 1, maxdisp, device='cuda').repeat(size[0])[:, None, None, None] * stride
117 | batch_disp = batch_disp - batch_shift.float()
118 | batch_feat_l = feat_l[:, None, :, :, :].repeat(1, maxdisp * 2 - 1, 1, 1, 1).view(-1, size[-3], size[-2], size[-1])
119 | batch_feat_r = feat_r[:, None, :, :, :].repeat(1, maxdisp * 2 - 1, 1, 1, 1).view(-1, size[-3], size[-2], size[-1])
120 |
121 | # cost = torch.norm(batch_feat_l - warp(batch_feat_r, batch_disp), 1, 1)
122 |
123 | cost = batch_feat_l - warp(batch_feat_r, batch_disp)
124 |
125 | cost = cost.view(size[0], size[1], -1, size[2], size[3])
126 |
127 | return cost.contiguous()
128 |
129 |
130 |
131 |
132 | def warp(x, disp):
133 | """
134 |     warp an image/tensor (the right view) back to the left view, according to the disparity
135 |     x: [B, C, H, W] (right-view features)
136 |     disp: [B, 1, H, W] horizontal disparity
137 | """
138 | B, C, H, W = x.size()
139 | # mesh grid
140 | xx = torch.arange(0, W, device='cuda').view(1, -1).repeat(H, 1)
141 | yy = torch.arange(0, H, device='cuda').view(-1, 1).repeat(1, W)
142 | xx = xx.view(1, 1, H, W).repeat(B, 1, 1, 1)
143 | yy = yy.view(1, 1, H, W).repeat(B, 1, 1, 1)
144 | vgrid = torch.cat((xx, yy), 1).float()
145 |
146 | # vgrid = Variable(grid)
147 | vgrid[:,:1,:,:] = vgrid[:,:1,:,:] - disp
148 |
149 | # scale grid to [-1,1]
150 | vgrid[:, 0, :, :] = 2.0 * vgrid[:, 0, :, :].clone() / max(W - 1, 1) - 1.0
151 | vgrid[:, 1, :, :] = 2.0 * vgrid[:, 1, :, :].clone() / max(H - 1, 1) - 1.0
152 |
153 | vgrid = vgrid.permute(0, 2, 3, 1)
154 | #output = nn.functional.grid_sample(x, vgrid, align_corners=True )
155 | output = nn.functional.grid_sample(x, vgrid)
156 | return output
157 |
158 |
159 |
160 | def _build_cost_volume(cost_volume_type, refimg_fea, targetimg_fea, maxdisp):
161 | if cost_volume_type == "Concat":
162 |
163 | cost = _build_volume_2d_psmnet(refimg_fea, targetimg_fea, maxdisp=maxdisp // 8)
164 |
165 |
166 | elif cost_volume_type == "Distance_based":
167 |
168 | cost = _build_volume_2d_anynet(refimg_fea, targetimg_fea, maxdisp // 8, stride=1)
169 | cost = torch.unsqueeze(cost, 1)
170 |
171 |
172 | elif cost_volume_type == "Difference":
173 | #print("build difference")
174 | cost = _build_volume_2d_difference(refimg_fea, targetimg_fea, maxdisp // 8)
175 | #print("cost size:", cost.shape)
176 |
177 |
178 | else:
179 |         raise AssertionError("please define cost volume types")
180 |
181 | return cost
182 |
183 |
184 |
185 |
186 |
187 | def _build_redidual_cost_volume(cost_volume_type, L2, R2, wflow, maxdisp):
188 | if cost_volume_type == "Concat":
189 |
190 | cost_residual = _build_volume_2d3_psmnet(L2, R2, wflow, maxdisp)
191 |
192 | elif cost_volume_type == "Distance_based":
193 | cost_residual = _build_volume_2d3_anynet(L2, R2, wflow, maxdisp)
194 | cost_residual = torch.unsqueeze(cost_residual, 1)
195 |
196 |
197 | elif cost_volume_type == "Difference":
198 | cost_residual = _build_volume_2d3_difference(L2, R2, wflow, maxdisp)
199 | # cost_residual = torch.unsqueeze(cost_residual, 1)
200 |
201 | else:
202 |         raise AssertionError("please define cost volume types")
203 |
204 | return cost_residual
205 |
--------------------------------------------------------------------------------
/models/cspn.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: Xinjing Cheng & Peng Wang
3 |
4 | """
5 |
6 | import torch.nn as nn
7 | import math
8 | import torch.utils.model_zoo as model_zoo
9 | import torch
10 | from torch.autograd import Variable
11 | import torch.nn.functional as F
12 |
13 |
14 | class Affinity_Propagate(nn.Module):
15 |
16 | def __init__(self,
17 | prop_time,
18 | prop_kernel,
19 | norm_type='8sum'):
20 | """
21 |
22 | Inputs:
23 | prop_time: how many steps for CSPN to perform
24 |             prop_kernel: the size of the kernel (currently only 3x3 is supported)
25 |             norm_type: way to normalize the affinity
26 |                 '8sum': normalize using the 8 surrounding neighbours
27 |                 '8sum_abs': normalization enforcing the affinity to be positive;
28 |                 this will lead the center affinity to be 0
29 | """
30 | super(Affinity_Propagate, self).__init__()
31 | self.prop_time = prop_time
32 | self.prop_kernel = prop_kernel
33 | assert prop_kernel == 3, 'this version only support 8 (3x3 - 1) neighborhood'
34 |
35 | self.norm_type = norm_type
36 | assert norm_type in ['8sum', '8sum_abs']
37 |
38 | self.in_feature = 1
39 | self.out_feature = 1
40 |
41 | self.sum_conv = nn.Conv3d(in_channels=8,
42 | out_channels=1,
43 | kernel_size=(1, 1, 1),
44 | stride=1,
45 | padding=0,
46 | bias=False)
47 | weight = torch.ones(1, 8, 1, 1, 1).cuda()
48 | self.sum_conv.weight = nn.Parameter(weight)
49 | for param in self.sum_conv.parameters():
50 | param.requires_grad = False
51 |
52 |
53 | def forward(self, guidance, blur_depth):
54 |
55 | # self.sum_conv = nn.Conv3d(in_channels=8,
56 | # out_channels=1,
57 | # kernel_size=(1, 1, 1),
58 | # stride=1,
59 | # padding=0,
60 | # bias=False)
61 | weight = torch.ones(1, 8, 1, 1, 1).cuda()
62 | self.sum_conv.weight = nn.Parameter(weight)
63 | for param in self.sum_conv.parameters():
64 | param.requires_grad = False
65 |
66 | gate_wb, gate_sum = self.affinity_normalization(guidance)
67 |
68 | # pad input and convert to 8 channel 3D features
69 | raw_depth_input = blur_depth
70 |
71 | #blur_depht_pad = nn.ZeroPad2d((1,1,1,1))
72 | result_depth = blur_depth
73 |
74 |
75 |
76 | for i in range(self.prop_time):
77 | # one propagation
78 | spn_kernel = self.prop_kernel
79 | #print('11111111111111111111111')
80 | result_depth = self.pad_blur_depth(result_depth)
81 | neigbor_weighted_sum = self.sum_conv(gate_wb * result_depth)
82 | neigbor_weighted_sum = neigbor_weighted_sum.squeeze(1)
83 | neigbor_weighted_sum = neigbor_weighted_sum[:, :, 1:-1, 1:-1]
84 | result_depth = neigbor_weighted_sum
85 |
86 | if '8sum' in self.norm_type:
87 | result_depth = (1.0 - gate_sum) * raw_depth_input + result_depth
88 | else:
89 | raise ValueError('unknown norm %s' % self.norm_type)
90 |
91 |
92 | return result_depth
93 |
94 | def affinity_normalization(self, guidance):
95 |
96 | # normalize features
97 | if 'abs' in self.norm_type:
98 | guidance = torch.abs(guidance)
99 |
100 | gate1_wb_cmb = guidance.narrow(1, 0 , self.out_feature)
101 | gate2_wb_cmb = guidance.narrow(1, 1 * self.out_feature, self.out_feature)
102 | gate3_wb_cmb = guidance.narrow(1, 2 * self.out_feature, self.out_feature)
103 | gate4_wb_cmb = guidance.narrow(1, 3 * self.out_feature, self.out_feature)
104 | gate5_wb_cmb = guidance.narrow(1, 4 * self.out_feature, self.out_feature)
105 | gate6_wb_cmb = guidance.narrow(1, 5 * self.out_feature, self.out_feature)
106 | gate7_wb_cmb = guidance.narrow(1, 6 * self.out_feature, self.out_feature)
107 | gate8_wb_cmb = guidance.narrow(1, 7 * self.out_feature, self.out_feature)
108 |
109 | # gate1:left_top, gate2:center_top, gate3:right_top
110 | # gate4:left_center, , gate5: right_center
111 | # gate6:left_bottom, gate7: center_bottom, gate8: right_bottm
112 |
113 | # top pad
114 | left_top_pad = nn.ZeroPad2d((0,2,0,2))
115 | gate1_wb_cmb = left_top_pad(gate1_wb_cmb).unsqueeze(1)
116 |
117 | center_top_pad = nn.ZeroPad2d((1,1,0,2))
118 | gate2_wb_cmb = center_top_pad(gate2_wb_cmb).unsqueeze(1)
119 |
120 | right_top_pad = nn.ZeroPad2d((2,0,0,2))
121 | gate3_wb_cmb = right_top_pad(gate3_wb_cmb).unsqueeze(1)
122 |
123 | # center pad
124 | left_center_pad = nn.ZeroPad2d((0,2,1,1))
125 | gate4_wb_cmb = left_center_pad(gate4_wb_cmb).unsqueeze(1)
126 |
127 | right_center_pad = nn.ZeroPad2d((2,0,1,1))
128 | gate5_wb_cmb = right_center_pad(gate5_wb_cmb).unsqueeze(1)
129 |
130 | # bottom pad
131 | left_bottom_pad = nn.ZeroPad2d((0,2,2,0))
132 | gate6_wb_cmb = left_bottom_pad(gate6_wb_cmb).unsqueeze(1)
133 |
134 | center_bottom_pad = nn.ZeroPad2d((1,1,2,0))
135 | gate7_wb_cmb = center_bottom_pad(gate7_wb_cmb).unsqueeze(1)
136 |
137 | right_bottm_pad = nn.ZeroPad2d((2,0,2,0))
138 | gate8_wb_cmb = right_bottm_pad(gate8_wb_cmb).unsqueeze(1)
139 |
140 | gate_wb = torch.cat((gate1_wb_cmb,gate2_wb_cmb,gate3_wb_cmb,gate4_wb_cmb,
141 | gate5_wb_cmb,gate6_wb_cmb,gate7_wb_cmb,gate8_wb_cmb), 1)
142 |
143 | # normalize affinity using their abs sum
144 | gate_wb_abs = torch.abs(gate_wb)
145 | abs_weight = self.sum_conv(gate_wb_abs)
146 |
147 | gate_wb = torch.div(gate_wb, abs_weight)
148 | gate_sum = self.sum_conv(gate_wb)
149 |
150 | gate_sum = gate_sum.squeeze(1)
151 | gate_sum = gate_sum[:, :, 1:-1, 1:-1]
152 |
153 | return gate_wb, gate_sum
154 |
155 |
156 | def pad_blur_depth(self, blur_depth):
157 | # top pad
158 | left_top_pad = nn.ZeroPad2d((0,2,0,2))
159 | blur_depth_1 = left_top_pad(blur_depth).unsqueeze(1)
160 | center_top_pad = nn.ZeroPad2d((1,1,0,2))
161 | blur_depth_2 = center_top_pad(blur_depth).unsqueeze(1)
162 | right_top_pad = nn.ZeroPad2d((2,0,0,2))
163 | blur_depth_3 = right_top_pad(blur_depth).unsqueeze(1)
164 |
165 | # center pad
166 | left_center_pad = nn.ZeroPad2d((0,2,1,1))
167 | blur_depth_4 = left_center_pad(blur_depth).unsqueeze(1)
168 | right_center_pad = nn.ZeroPad2d((2,0,1,1))
169 | blur_depth_5 = right_center_pad(blur_depth).unsqueeze(1)
170 |
171 | # bottom pad
172 | left_bottom_pad = nn.ZeroPad2d((0,2,2,0))
173 | blur_depth_6 = left_bottom_pad(blur_depth).unsqueeze(1)
174 | center_bottom_pad = nn.ZeroPad2d((1,1,2,0))
175 | blur_depth_7 = center_bottom_pad(blur_depth).unsqueeze(1)
176 | right_bottm_pad = nn.ZeroPad2d((2,0,2,0))
177 | blur_depth_8 = right_bottm_pad(blur_depth).unsqueeze(1)
178 |
179 | result_depth = torch.cat((blur_depth_1, blur_depth_2, blur_depth_3, blur_depth_4,
180 | blur_depth_5, blur_depth_6, blur_depth_7, blur_depth_8), 1)
181 | return result_depth
182 |
183 |
184 | def normalize_gate(self, guidance):
185 | gate1_x1_g1 = guidance.narrow(1,0,1)
186 | gate1_x1_g2 = guidance.narrow(1,1,1)
187 | gate1_x1_g1_abs = torch.abs(gate1_x1_g1)
188 | gate1_x1_g2_abs = torch.abs(gate1_x1_g2)
189 | elesum_gate1_x1 = torch.add(gate1_x1_g1_abs, gate1_x1_g2_abs)
190 | gate1_x1_g1_cmb = torch.div(gate1_x1_g1, elesum_gate1_x1)
191 | gate1_x1_g2_cmb = torch.div(gate1_x1_g2, elesum_gate1_x1)
192 | return gate1_x1_g1_cmb, gate1_x1_g2_cmb
193 |
194 |
195 | def max_of_4_tensor(self, element1, element2, element3, element4):
196 | max_element1_2 = torch.max(element1, element2)
197 | max_element3_4 = torch.max(element3, element4)
198 | return torch.max(max_element1_2, max_element3_4)
199 |
200 | def max_of_8_tensor(self, element1, element2, element3, element4, element5, element6, element7, element8):
201 | max_element1_2 = self.max_of_4_tensor(element1, element2, element3, element4)
202 | max_element3_4 = self.max_of_4_tensor(element5, element6, element7, element8)
203 | return torch.max(max_element1_2, max_element3_4)
204 |
205 |
206 |
--------------------------------------------------------------------------------
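Affinity_Propagate above refines an initial disparity with learned 8-neighbour affinities, so the guidance tensor has 8 channels and the disparity input one. The module hard-codes .cuda(), so this sketch (with dummy tensors and illustrative sizes) only runs on a GPU machine:

```python
import torch
from models.cspn import Affinity_Propagate

if torch.cuda.is_available():
    cspn = Affinity_Propagate(prop_time=4, prop_kernel=3, norm_type='8sum').cuda()

    guidance = torch.randn(1, 8, 64, 128).cuda()      # 8 affinity maps, one per neighbour
    coarse_disp = torch.randn(1, 1, 64, 128).cuda()   # initial disparity to be refined

    refined = cspn(guidance, coarse_disp)
    print(refined.shape)  # torch.Size([1, 1, 64, 128])
```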
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.parallel
6 | import torch.optim as optim
7 | import torch.utils.data
8 | import torch.nn.functional as F
9 | import time
10 | from torch.autograd import Variable
11 | from dataloader import listflowfile as lt
12 | from dataloader import SecenFlowLoader as DA
13 | import utils.logger as logger
14 | from utils.flops_hook import profile
15 | from models.LWANet import *
16 |
17 |
18 | parser = argparse.ArgumentParser(description='LWANet with Sceneflow dataset')
19 | parser.add_argument('--maxdisp', type=int, default=192, help='maximum disparity')
20 | parser.add_argument('--loss_weights', type=float, nargs='+', default=[1., 1.])
21 | parser.add_argument('--maxdisplist', type=int, nargs='+', default=[24, 3, 3])
22 | parser.add_argument('--lr', type=float, default=5e-4, help='learning rate')
23 | parser.add_argument('--with_cspn', type =bool, default= True, help='with cspn network or not')
24 | parser.add_argument('--datapath', default='/data6/wsgan/SenceFlow/train/', help='datapath')
25 | parser.add_argument('--epochs', type=int, default=50, help='number of epochs to train')
26 | parser.add_argument('--train_bsize', type=int, default=16, help='batch size for training (default: 16)')
27 | parser.add_argument('--test_bsize', type=int, default=8, help='batch size for testing (default: 8)')
28 | parser.add_argument('--save_path', type=str, default='./results/sceneflow/', help='the path of saving checkpoints and log')
29 | parser.add_argument('--resume', type=str, default=None, help='resume path')
30 | parser.add_argument('--print_freq', type=int, default=400, help='print frequency')
31 |
32 | parser.add_argument('--model_types', type=str, default='LWANet', help='model_types : 3D, P3D')
33 | parser.add_argument('--conv_3d_types1', type=str, default='P3D', help='model_types : 3D, P3D ')
34 | parser.add_argument('--conv_3d_types2', type=str, default='P3D', help='model_types : 3D, P3D')
35 | parser.add_argument('--cost_volume', type=str, default='Difference', help='cost_volume type : "Concat" , "Difference" or "Distance_based" ')
36 | parser.add_argument('--train', type =bool, default=True, help='train or test ')
37 |
38 |
39 | args = parser.parse_args()
40 |
41 | #CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py
42 |
43 |
44 | def main():
45 | global args
46 |
47 | train_left_img, train_right_img, train_left_disp, test_left_img, test_right_img, test_left_disp = lt.dataloader(
48 | args.datapath)
49 | TrainImgLoader = torch.utils.data.DataLoader(
50 | DA.myImageFloder(train_left_img, train_right_img, train_left_disp, True),
51 | batch_size=args.train_bsize, shuffle=True, num_workers=4, drop_last=False)
52 | TestImgLoader = torch.utils.data.DataLoader(
53 | DA.myImageFloder(test_left_img, test_right_img, test_left_disp, False),
54 | batch_size=args.test_bsize, shuffle=False, num_workers=4, drop_last=False)
55 |
56 | if not os.path.isdir(args.save_path):
57 | os.makedirs(args.save_path)
58 | log = logger.setup_logger(args.save_path + '/training.log')
59 | for key, value in sorted(vars(args).items()):
60 | log.info(str(key) + ': ' + str(value))
61 |
62 |
63 | model = LWANet(args)
64 |
65 |
66 | # FLOPs, params = count_flops(model.cuda())
67 | # log.info('Number of model parameters: {}'.format(params))
68 | # log.info('Number of model FLOPs: {}'.format(FLOPs))
69 |
70 |
71 | model = nn.DataParallel(model).cuda()
72 | optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999))
73 |
74 | args.start_epoch = 0
75 | if args.resume:
76 | if os.path.isfile(args.resume):
77 | log.info("=> loading checkpoint '{}'".format(args.resume))
78 | checkpoint = torch.load(args.resume)
79 | args.start_epoch = checkpoint['epoch']
80 | model.load_state_dict(checkpoint['state_dict'])
81 | optimizer.load_state_dict(checkpoint['optimizer'])
82 | log.info("=> loaded checkpoint '{}' (epoch {})"
83 | .format(args.resume, checkpoint['epoch']))
84 | else:
85 | log.info("=> no checkpoint found at '{}'".format(args.resume))
86 | log.info("=> Will start from scratch.")
87 | else:
88 | log.info('Not Resume')
89 |
90 | start_full_time = time.time()
91 |
92 | if args.train:
93 | for epoch in range(args.start_epoch, args.epochs):
94 | log.info('This is {}-th epoch'.format(epoch))
95 |
96 | train(TrainImgLoader, model, optimizer, log, epoch)
97 |
98 | savefilename = args.save_path + '/checkpoint_' + str(epoch) + '.tar'
99 |
100 | torch.save({
101 | 'epoch': epoch,
102 | 'state_dict': model.state_dict(),
103 | 'optimizer': optimizer.state_dict(),
104 | }, savefilename)
105 |
106 | if not epoch % 10:
107 | test(TestImgLoader, model, log)
108 |
109 | test(TestImgLoader, model, log)
110 | log.info('full training time = {:.2f} Hours'.format((time.time() - start_full_time) / 3600))
111 |
112 |
113 | def train(dataloader, model, optimizer, log, epoch=0):
114 |
115 |
116 | stages = 2
117 | losses = [AverageMeter() for _ in range(stages)]
118 | length_loader = len(dataloader)
119 |
120 | model.train()
121 |
122 | for batch_idx, (imgL, imgR, disp_L) in enumerate(dataloader):
123 | imgL = imgL.float().cuda()
124 | imgR = imgR.float().cuda()
125 | disp_L = disp_L.float().cuda()
126 |
127 | optimizer.zero_grad()
128 | mask = (disp_L < args.maxdisp) & (disp_L > 0)
129 | if mask.float().sum() == 0:
130 | continue
131 |
132 | mask.detach_()
133 |
134 | outputs, self_supervised_loss = model(imgL, imgR)
135 | stages = len(outputs)
136 |
137 | outputs = [torch.squeeze(output, 1) for output in outputs]
138 |
139 | loss = [args.loss_weights[x] * F.smooth_l1_loss(outputs[x][mask], disp_L[mask], size_average=True)
140 | for x in range(stages)]
141 |
142 | sum(loss).backward()
143 | optimizer.step()
144 |
145 | for idx in range(stages):
146 | losses[idx].update(loss[idx].item()/args.loss_weights[idx])
147 |
148 | if batch_idx % args.print_freq ==0:
149 | info_str = ['Stage {} = {:.2f}({:.2f})'.format(x, losses[x].val, losses[x].avg) for x in range(stages)]
150 | info_str = '\t'.join(info_str)
151 |
152 | log.info('Epoch{} [{}/{}] {}'.format(
153 | epoch, batch_idx, length_loader, info_str))
154 | info_str = '\t'.join(['Stage {} = {:.2f}'.format(x, losses[x].avg) for x in range(stages)])
155 | log.info('Average train loss = ' + info_str)
156 |
157 |
158 |
159 | def test(dataloader, model, log):
160 |
161 | stages = 2
162 | EPEs = [AverageMeter() for _ in range(stages)]
163 | length_loader = len(dataloader)
164 |
165 | model.eval()
166 |
167 | inference_time = 0
168 | for batch_idx, (imgL, imgR, disp_L) in enumerate(dataloader):
169 | imgL = imgL.float().cuda()
170 | imgR = imgR.float().cuda()
171 | disp_L = disp_L.float().cuda()
172 |
173 | mask = disp_L < args.maxdisp
174 | with torch.no_grad():
175 |
176 | time_start = time.perf_counter()
177 |
178 |
179 | outputs, monoloss = model(imgL, imgR)
180 |
181 | single_inference_time = time.perf_counter() - time_start
182 |
183 | inference_time += single_inference_time
184 |
185 |
186 | stages = len(outputs)
187 | for x in range(stages):
188 | if len(disp_L[mask]) == 0:
189 | EPEs[x].update(0)
190 | continue
191 | output = torch.squeeze(outputs[x], 1)
192 | output = output[:, 4:, :]
193 | EPEs[x].update((output[mask] - disp_L[mask]).abs().mean())
194 |
195 | if batch_idx % args.print_freq == 0:
196 | info_str = '\t'.join(['Stage {} = {:.2f}({:.2f})'.format(x, EPEs[x].val, EPEs[x].avg) for x in range(stages)])
197 |
198 | log.info('[{}/{}] {}'.format(
199 | batch_idx, length_loader, info_str))
200 |
201 | log.info(('=> Mean inference time for %d images: %.3fs' % (
202 | length_loader, inference_time / length_loader)))
203 |
204 | info_str = ', '.join(['Stage {}={:.2f}'.format(x, EPEs[x].avg) for x in range(stages)])
205 | log.info('Average test EPE = ' + info_str)
206 |
207 |
208 | def adjust_learning_rate(optimizer, epoch):
209 | if epoch <= 20:
210 | lr = args.lr
211 |
212 |     elif 20
--------------------------------------------------------------------------------
/Online_adaptation.py:
--------------------------------------------------------------------------------
98 |             log.info("=> loaded pretrained model '{}'"
99 | .format(args.pretrained))
100 | else:
101 | log.info("=> no pretrained model found at '{}'".format(args.pretrained))
102 | log.info("=> Will start from scratch.")
103 | args.start_epoch = 0
104 |
105 |
106 | cudnn.benchmark = True
107 |
108 | if args.adaptation_type == "self_supervise":
109 | model.train()
110 | loss_file = open(args.save_path + '/self_supervise' + '.txt', 'w')
111 |
112 |
113 | elif args.adaptation_type == "GT_supervise":
114 | model.train()
115 | loss_file = open(args.save_path + '/GT_supervise' + '.txt', 'w')
116 |
117 |
118 | elif args.adaptation_type == "no_supervise":
119 |
120 | loss_file = open(args.save_path + '/no_supervise' + '.txt', 'w')
121 |
122 |
123 | train(TrainImgLoader, model, optimizer, log, loss_file, args)
124 |
125 |
126 |
127 | def train(dataloader, model, optimizer, log, loss_file, args):
128 |
129 | losses = [AverageMeter() for _ in range(2)]
130 | length_loader = len(dataloader)
131 | D1s = [AverageMeter() for _ in range(2)]
132 |
133 | start_full_time = time.time()
134 | for batch_idx, (imgL, imgR, disp_L) in enumerate(dataloader):
135 | imgL = imgL.float().cuda()
136 | imgR = imgR.float().cuda()
137 | disp_L = disp_L.float().cuda()
138 | #print('train imgR size:', imgR.shape)
139 |
140 | optimizer.zero_grad()
141 | mask = disp_L > 0
142 | mask = mask*(disp_L<192)
143 | mask.detach_()
144 |
145 | single_update_time=time.time()
146 |
147 | #outputs = model(imgL, imgR)
148 | if args.adaptation_type == "no_supervise":
149 | model.eval()
150 | with torch.no_grad():
151 | pred, mono_loss = model(imgL, imgR)
152 |
153 | outputs = [torch.squeeze(output, 1) for output in pred]
154 |
155 | num_out = len(pred)
156 | loss = [args.loss_weights[x] * F.smooth_l1_loss(outputs[x][mask], disp_L[mask], size_average=True)
157 | for x in range(num_out)]
158 |
159 |
160 | num_out = len(pred)
161 |
162 |
163 | elif args.adaptation_type == "self_supervise":
164 | model.train()
165 |
166 | pred, mono_loss = model(imgL, imgR)
167 | outputs = [torch.squeeze(output, 1) for output in pred]
168 | num_out = len(pred)
169 | loss = [args.loss_weights[x] * F.smooth_l1_loss(outputs[x][mask], disp_L[mask], size_average=True)
170 | for x in range(num_out)]
171 |
172 | sum(mono_loss).backward()
173 |
174 | optimizer.step()
175 |
176 | elif args.adaptation_type == "GT_supervise":
177 | model.train()
178 |
179 | pred, mono_loss = model(imgL, imgR)
180 |
181 | outputs = [torch.squeeze(output, 1) for output in pred]
182 |
183 | num_out = len(pred)
184 | loss = [args.loss_weights[x] * F.smooth_l1_loss(outputs[x][mask], disp_L[mask], size_average=True)
185 | for x in range(num_out)]
186 |
187 | sum(loss).backward()
188 | optimizer.step()
189 |
190 |
191 |
192 |         print('single_update_time: {:.4f} seconds'.format(time.time() - single_update_time))
193 | # image out and error estimation
194 |
195 | # three pixel error
196 |
197 | output = torch.squeeze(pred[1], 1)
198 | D1s[1].update(error_estimating(output, disp_L).item())
199 | print('output size:', output.shape)
200 |
201 |
202 |
203 | # save the adaptation disparity
204 | if args.save_disparity :
205 |
206 | plt.imshow(output.squeeze(0).cpu().detach().numpy())
207 | plt.axis('off')
208 |
209 | plt.gcf().set_size_inches(1216 / 100, 320 / 100)
210 | plt.gca().xaxis.set_major_locator(plt.NullLocator())
211 | plt.gca().yaxis.set_major_locator(plt.NullLocator())
212 | plt.subplots_adjust(top=1, bottom=0, left=0, right=1, hspace=0, wspace=0)
213 | plt.margins(0, 0)
214 |
215 | plt.savefig(args.save_path+'/disparity/{}.png'.format(batch_idx))
216 |
217 | # if args.save_disparity:
218 | #
219 | # imgL = imgL.squeeze(0).permute(1,2,0)
220 | # #print("imgL size:", imgL.shape)
221 | # plt.imshow(imgL.cpu().detach().numpy())
222 | # plt.axis('off')
223 | #
224 | # plt.gcf().set_size_inches(1216 / 100, 320 / 100)
225 | # plt.gca().xaxis.set_major_locator(plt.NullLocator())
226 | # plt.gca().yaxis.set_major_locator(plt.NullLocator())
227 | # plt.subplots_adjust(top=1, bottom=0, left=0, right=1, hspace=0, wspace=0)
228 | # plt.margins(0, 0)
229 | #
230 | # plt.savefig(args.save_path + '/disparity/{}.png'.format(batch_idx))
231 | #
232 |
233 |
234 |
235 | loss_file.write('{:.4f}\n'.format(D1s[1].val))
236 |
237 | for idx in range(num_out):
238 | losses[idx].update(loss[idx].item())
239 |
240 |
241 | info_str = ['Stage {} = {:.2f}({:.2f})'.format(x, losses[x].val, losses[x].avg) for x in range(num_out)]
242 | info_str = '\t'.join(info_str)
243 |
244 | log.info('Epoch{} [{}/{}] {}'.format( 1, batch_idx, length_loader, info_str))
245 |
246 | end_time = time.time()
247 |
248 |     log.info('full adaptation time = {:.2f} hours ({:.4f} seconds)'.format(
249 |         (end_time - start_full_time) / 3600, end_time - start_full_time))
250 |
251 | # summary
252 | info_str = ', '.join(['Stage {}={:.4f}'.format(x, D1s[x].avg) for x in range(num_out)])
253 |
254 | log.info('Average test 3-Pixel Error = ' + info_str)
255 |
256 | info_str = '\t'.join(['Stage {} = {:.2f}'.format(x, losses[x].avg) for x in range(num_out)])
257 | log.info('Average train loss = ' + info_str)
258 |
259 | loss_file.close()
260 |
261 |
262 |
263 | def error_estimating(disp, ground_truth, maxdisp=192):
264 |
265 | gt = ground_truth
266 | mask = gt > 0
267 | mask = mask * (gt < maxdisp)
268 |
269 | errmap = torch.abs(disp - gt)
270 | err3 = ((errmap[mask] > 3.) & (errmap[mask] / gt[mask] > 0.05)).sum()
271 | return err3.float() / mask.sum().float()
272 |
273 |
274 |
275 | class AverageMeter(object):
276 | """Computes and stores the average and current value"""
277 |
278 | def __init__(self):
279 | self.reset()
280 |
281 | def reset(self):
282 | self.val = 0
283 | self.avg = 0
284 | self.sum = 0
285 | self.count = 0
286 |
287 | def update(self, val, n=1):
288 | self.val = val
289 | self.sum += val * n
290 | self.count += n
291 | self.avg = self.sum / self.count
292 |
293 |
294 | if __name__ == '__main__':
295 | main()
296 |
297 |
298 |
299 |
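300 | # A minimal, hypothetical self-check of error_estimating above (the helper is not called
301 | # anywhere): a valid pixel counts as erroneous only when its absolute disparity error is
302 | # both greater than 3 px and greater than 5% of the ground-truth disparity.
303 | def _error_estimating_sketch():
304 |     gt = torch.tensor([[10., 100., 0.]])     # 0 marks an invalid pixel (excluded by the mask)
305 |     pred = torch.tensor([[14., 102., 50.]])  # errors: 4 px (40%), 2 px (2%), ignored
306 |     assert abs(error_estimating(pred, gt).item() - 0.5) < 1e-6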
--------------------------------------------------------------------------------
/finetune.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.parallel
6 | import torch.optim as optim
7 | import torch.utils.data
8 | import torch.nn.functional as F
9 | import time
10 | from dataloader import KITTILoader as DA
11 | from dataloader import KITTIdatalist as ls
12 | import utils.logger as logger
13 | import torch.backends.cudnn as cudnn
14 | from models.LWANet import *
15 |
16 | import pdb
17 |
18 | # Check GPU usage:
19 | # watch --color -n1 gpustat -cpu
20 |
21 |
22 | parser = argparse.ArgumentParser(description='LWANet finetune on KITTI')
23 | parser.add_argument('--maxdisp', type=int, default=192, help='maximum disparity')
24 | parser.add_argument('--loss_weights', type=float, nargs='+', default=[1., 1.])
25 | parser.add_argument('--max_disparity', type=int, default=192)
26 | parser.add_argument('--maxdisplist', type=int, nargs='+', default=[24, 3, 3])
27 | parser.add_argument('--with_cspn', type =bool, default= True, help='with cspn network or not')
28 | parser.add_argument('--cost_volume', type=str, default='Difference', help='cost_volume type : "Concat" , "Difference" or "Distance_based"')
29 | parser.add_argument('--lr', type=float, default=5e-4*0.5, help='learning rate')
30 | parser.add_argument('--epochs', type=int, default=1001, help='number of epochs to train')
31 | parser.add_argument('--train_bsize', type=int, default=8, help='batch size for training (default: 8)')
32 | parser.add_argument('--test_bsize', type=int, default=8,help='batch size for testing (default: 8)')
33 | parser.add_argument('--resume', type=str, default= None, help='resume path')
34 | parser.add_argument('--print_freq', type=int, default=10, help='print frequency')
35 | parser.add_argument('--pretrained', type=str, default=None, help='pretrained model path')
36 | parser.add_argument('--model_types', type=str, default='LWANet', help='model_types : 3D OR P3D')
37 | parser.add_argument('--conv_3d_types1', type=str, default='P3D', help='model_types : 3D, P3D ')
38 | parser.add_argument('--conv_3d_types2', type=str, default='P3D', help='model_types : 3D, P3D')
39 |
40 |
41 | parser.add_argument('--save_path', type=str, default='./results/finetune2015/', help='the path of saving checkpoints and log')
42 | parser.add_argument('--split_for_val', type =bool, default=False, help='finetune for submission or for validation')
43 | parser.add_argument('--datatype', default='mix', help='finetune dataset: 2012, 2015, mix')
44 | parser.add_argument('--datapath2015', default='/data6/wsgan/KITTI/KITTI2015/training/', help='datapath')
45 | parser.add_argument('--datapath2012', default='/data6/wsgan/KITTI/KITTI2012/training/', help='datapath')
46 |
47 |
48 | args = parser.parse_args()
49 |
50 |
51 | #CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python finetune.py
52 |
53 | def main():
54 | global args
55 | log = logger.setup_logger(args.save_path + '/training.log')
56 |
57 | if args.datatype == '2015':
58 |
59 | all_left_img, all_right_img, all_left_disp, test_left_img, test_right_img, test_left_disp = ls.dataloader2015(
60 | args.datapath2015, log, split=args.split_for_val)
61 |
62 | elif args.datatype == '2012':
63 |
64 | all_left_img, all_right_img, all_left_disp, test_left_img, test_right_img, test_left_disp = ls.dataloader2012(
65 | args.datapath2012, log, split = False)
66 |
67 | elif args.datatype == 'mix':
68 |
69 | all_left_img_2015, all_right_img_2015, all_left_disp_2015, test_left_img_2015, test_right_img_2015, test_left_disp_2015 = ls.dataloader2015(
70 | args.datapath2015, log, split=False)
71 | all_left_img_2012, all_right_img_2012, all_left_disp_2012, test_left_img_2012, test_right_img_2012, test_left_disp_2012 = ls.dataloader2012(
72 | args.datapath2012, log, split=False)
73 | all_left_img, all_right_img, all_left_disp, test_left_img, test_right_img, test_left_disp = \
74 | all_left_img_2015 + all_left_img_2012, all_right_img_2015 + all_right_img_2012, \
75 | all_left_disp_2015 + all_left_disp_2012, test_left_img_2015 + test_left_img_2012, \
76 | test_right_img_2015 + test_right_img_2012, test_left_disp_2015 + test_left_disp_2012
77 | else:
78 |
79 |         raise AssertionError("please define the finetune dataset")
80 |
81 | TrainImgLoader = torch.utils.data.DataLoader(
82 | DA.myImageFloder(all_left_img, all_right_img, all_left_disp, True),
83 | batch_size=args.train_bsize, shuffle=True, num_workers=4, drop_last=False)
84 |
85 | TestImgLoader = torch.utils.data.DataLoader(
86 | DA.myImageFloder(test_left_img, test_right_img, test_left_disp, False),
87 | batch_size=args.test_bsize, shuffle=False, num_workers=4, drop_last=False)
88 |
89 | if not os.path.isdir(args.save_path):
90 | os.makedirs(args.save_path)
91 | for key, value in sorted(vars(args).items()):
92 | log.info(str(key) + ': ' + str(value))
93 |
94 |
95 | model = LWANet(args)
96 |
97 |
98 | model = nn.DataParallel(model).cuda()
99 | optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999))
100 | log.info('Number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()])))
101 |
102 | if args.pretrained:
103 | if os.path.isfile(args.pretrained):
104 | checkpoint = torch.load(args.pretrained)
105 | model.load_state_dict(checkpoint['state_dict'], strict=False)
106 | log.info("=> loaded pretrained model '{}'"
107 | .format(args.pretrained))
108 | else:
109 | log.info("=> no pretrained model found at '{}'".format(args.pretrained))
110 | log.info("=> Will start from scratch.")
111 | args.start_epoch = 0
112 | if args.resume:
113 | if os.path.isfile(args.resume):
114 | log.info("=> loading checkpoint '{}'".format(args.resume))
115 | checkpoint = torch.load(args.resume)
116 | model.load_state_dict(checkpoint['state_dict'], strict=False)
117 | optimizer.load_state_dict(checkpoint['optimizer'])
118 | args.start_epoch = checkpoint['epoch'] + 1
119 | log.info("=> loaded checkpoint '{}' (epoch {})"
120 | .format(args.resume, checkpoint['epoch']))
121 | else:
122 | log.info("=> no checkpoint found at '{}'".format(args.resume))
123 | log.info("=> Will start from scratch.")
124 | else:
125 | log.info('Not Resume')
126 | cudnn.benchmark = True
127 |
128 | start_full_time = time.time()
129 |
130 |
131 |
132 | for epoch in range(args.start_epoch, args.epochs):
133 | log.info('This is {}-th epoch'.format(epoch))
134 | adjust_learning_rate(optimizer, epoch)
135 |
136 | train(TrainImgLoader, model, optimizer, log, epoch)
137 |
138 | if epoch % 100 == 0:
139 | savefilename = args.save_path + '/finetune_' + str(epoch) + '.tar'
140 | torch.save({
141 | 'epoch': epoch,
142 | 'state_dict': model.state_dict(),
143 | 'optimizer': optimizer.state_dict(),
144 | }, savefilename)
145 |
146 |
147 |
148 | if epoch % 20 == 0:
149 | test(TestImgLoader, model, log)
150 |
151 |
152 |
153 | test(TestImgLoader, model, log)
154 | log.info('full training time = {:.2f} Hours'.format((time.time() - start_full_time) / 3600))
155 |
156 |
157 |
158 | def train(dataloader, model, optimizer, log, epoch=0):
159 |
160 | stages = 2
161 | losses = [AverageMeter() for _ in range(stages)]
162 | length_loader = len(dataloader)
163 |
164 | model.train()
165 |
166 | for batch_idx, (imgL, imgR, disp_L) in enumerate(dataloader):
167 |
168 | imgL = imgL.float().cuda()
169 | imgR = imgR.float().cuda()
170 | disp_L = disp_L.float().cuda()
171 |
172 | optimizer.zero_grad()
173 | mask = (disp_L > 0) & (disp_L < args.maxdisp)
174 | mask.detach_()
175 |
176 | pred, mono_loss = model(imgL, imgR)
177 |
178 | outputs = [torch.squeeze(output, 1) for output in pred]
179 |
180 | num_out = len(pred)
181 | loss = [args.loss_weights[x] * F.smooth_l1_loss(outputs[x][mask], disp_L[mask], size_average=True)
182 | for x in range(num_out)]
183 |
184 | sum(loss).backward()
185 |
186 | optimizer.step()
187 |
188 | for idx in range(num_out):
189 | losses[idx].update(loss[idx].item())
190 |
191 | if batch_idx % args.print_freq == 0:
192 | info_str = ['Stage {} = {:.2f}({:.2f})'.format(x, losses[x].val, losses[x].avg) for x in range(num_out)]
193 | info_str = '\t'.join(info_str)
194 |
195 | log.info('Epoch{} [{}/{}] {}'.format(
196 | epoch, batch_idx, length_loader, info_str))
197 |
198 |     info_str = '\t'.join(['Stage {} = {:.2f}'.format(x, losses[x].avg) for x in range(stages)])
199 | log.info('Average train loss = ' + info_str)
200 |
201 |
202 | def test(dataloader, model, log):
203 |
204 | stages = 3 + args.with_cspn
205 | D1s = [AverageMeter() for _ in range(stages)]
206 | length_loader = len(dataloader)
207 |
208 | model.eval()
209 |
210 | total_inference_time = 0
211 | for batch_idx, (imgL, imgR, disp_L) in enumerate(dataloader):
212 |
213 | imgL = imgL.float().cuda()
214 | imgR = imgR.float().cuda()
215 | disp_L = disp_L.float().cuda()
216 |
217 | with torch.no_grad():
218 |
219 | start_time = time.time()
220 | outputs, mono_loss = model(imgL, imgR)
221 | print(time.time() - start_time)
222 | total_inference_time += time.time() - start_time
223 |
224 | num_out = len(outputs)
225 | for x in range(num_out):
226 |
227 | output = torch.squeeze(outputs[x], 1)
228 | D1s[x].update(error_estimating(output, disp_L).item())
229 |
230 | info_str = '\t'.join(['Stage {} = {:.4f}({:.4f})'.format(x, D1s[x].val, D1s[x].avg) for x in range(num_out)])
231 |
232 | log.info('[{}/{}] {}'.format( batch_idx, length_loader, info_str))
233 |
234 | log.info("mean inference time: %.3fs " % (total_inference_time / length_loader))
235 | info_str = ', '.join(['Stage {}={:.4f}'.format(x, D1s[x].avg) for x in range(num_out)])
236 | log.info('Average test 3-Pixel Error = ' + info_str)
237 |
238 |
239 | def error_estimating(disp, ground_truth, maxdisp=192):
240 |
241 | gt = ground_truth
242 | mask = gt > 0
243 | mask = mask * (gt < maxdisp)
244 | errmap = torch.abs(disp - gt)
245 | err3 = ((errmap[mask] > 3.) & (errmap[mask] / gt[mask] > 0.05)).sum()
246 |
247 | return err3.float() / mask.sum().float()
248 |
249 |
250 | def adjust_learning_rate(optimizer, epoch):
251 | if epoch <= 600:
252 | lr = args.lr
253 |
254 |     elif 600 < epoch <= 1000:
255 | lr = args.lr*0.1
256 |
257 | else:
258 | lr = args.lr*0.01
259 |
260 | for param_group in optimizer.param_groups:
261 | param_group['lr'] = lr
262 |
263 |
264 |
265 | class AverageMeter(object):
266 | """Computes and stores the average and current value"""
267 |
268 | def __init__(self):
269 | self.reset()
270 |
271 | def reset(self):
272 | self.val = 0
273 | self.avg = 0
274 | self.sum = 0
275 | self.count = 0
276 |
277 | def update(self, val, n=1):
278 | self.val = val
279 | self.sum += val * n
280 | self.count += n
281 | self.avg = self.sum / self.count
282 |
283 |
284 |
285 | if __name__ == '__main__':
286 | main()
287 |
288 |
289 |
290 |
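291 | # Hypothetical equivalent of adjust_learning_rate above (not used by this script), assuming
292 | # one scheduler.step() call at the end of every epoch in place of the manual update:
293 | #
294 | #     scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[601, 1001], gamma=0.1)
295 | #     for epoch in range(args.start_epoch, args.epochs):
296 | #         train(TrainImgLoader, model, optimizer, log, epoch)
297 | #         scheduler.step()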
--------------------------------------------------------------------------------
/One_cycle.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import torch.nn.parallel
4 | import torch.optim as optim
5 | import torch.utils.data
6 | import torch.nn.functional as F
7 | import time
8 | import matplotlib.pyplot as plt
9 | from dataloader import KITTILoader_One_cycle as DA
10 | import utils.logger as logger
11 | import torch.backends.cudnn as cudnn
12 | import numpy as np
13 | import models
14 |
15 | import torch.nn as nn
16 | from models.LWADNet import *
17 |
18 |
19 | parser = argparse.ArgumentParser(description='AnyNet finetune on KITTI')
20 | parser.add_argument('--maxdisp', type=int, default=192,
21 |                     help='maximum disparity')
22 | parser.add_argument('--loss_weights', type=float, nargs='+', default=[1., 1., 1., 1.])
23 | parser.add_argument('--max_disparity', type=int, default=192)
24 | parser.add_argument('--maxdisplist', type=int, nargs='+', default=[24, 3, 3])
25 | parser.add_argument('--datatype', default='2015',
26 |                     help='finetune dataset: 2012 or 2015')
27 | parser.add_argument('--datapath', default='/home/wsgan/KITTI_DATASET/KITTI2015/training/', help='datapath')
28 | #parser.add_argument('--datapath', default='/home/um/GAN/Anynet/kitti2012/training/', help='datapath')
29 |
30 | parser.add_argument('--epochs', type=int, default=200,
31 | help='number of epochs to train')
32 | parser.add_argument('--train_bsize', type=int, default=1,
33 |                     help='batch size for training (default: 1)')
34 | parser.add_argument('--test_bsize', type=int, default=1,
35 |                     help='batch size for testing (default: 1)')
36 |
37 |
38 | parser.add_argument('--lr', type=float, default=5e-4,
39 | help='learning rate')
40 | parser.add_argument('--with_cspn', action='store_true', help='with cspn network or not')
41 |
42 |
43 | parser.add_argument('--model_types', type=str, default='original', help='model_types : LWANet_3D, mix, original')
44 | parser.add_argument('--conv_3d_types1', type=str, default='separate_only', help='model_types : normal, P3D, separate_only, ONLY_2D ')
45 | parser.add_argument('--conv_3d_types2', type=str, default='separate_only', help='model_types : normal, P3D, separate_only, ONLY_2D')
46 | parser.add_argument('--cost_volume', type=str, default='Difference', help='cost_volume type : "Concat" , "Difference" or "Distance_based" ')
47 |
48 |
49 | parser.add_argument('--adaptation_type', default='GT_supervise', help='adaptation_type : self_supervise, GT_supervise, no_supervise')
50 |
51 | parser.add_argument('--pretrained', type=str, default='/home/wsgan/LWANet/results/pretrain/original_Difference/separate_only/checkpoint_49.tar',
52 | help='pretrained model path')
53 | parser.add_argument('--save_path', type=str, default='./results/finetune_One_cycle/GT_supervise/',
54 | help='the path of saving checkpoints and log')
55 |
56 |
57 |
58 | args = parser.parse_args()
59 |
60 | if args.datatype == '2015':
61 | from dataloader import KITTIloader2015_One_cycle as ls
62 |
63 | elif args.datatype == '2012':
64 | from dataloader import KITTIloader2012 as ls
65 |
66 |
67 |
68 | # python One_cycle.py --with_cspn
69 |
70 | def main():
71 | global args
72 | log = logger.setup_logger(args.save_path + '/training.log')
73 | #log1 = logger.setup_logger(args.save_path + '/self_adaptive_loss.log')
74 |
75 | train_left_img, train_right_img, train_left_disp, test_left_img, test_right_img, test_left_disp = ls.dataloader(
76 | args.datapath,log)
77 |
78 | TrainImgLoader = torch.utils.data.DataLoader(
79 | DA.myImageFloder(train_left_img, train_right_img, train_left_disp, True),
80 | batch_size=args.train_bsize, shuffle=True, num_workers=4, drop_last=False)
81 |
82 | TestImgLoader = torch.utils.data.DataLoader(
83 | DA.myImageFloder(test_left_img, test_right_img, test_left_disp, False),
84 | batch_size=args.test_bsize, shuffle=False, num_workers=4, drop_last=False)
85 |
86 | if not os.path.isdir(args.save_path):
87 | os.makedirs(args.save_path)
88 |
89 | if not os.path.isdir(args.save_path+'/image'):
90 | os.makedirs(args.save_path+'/image')
91 |
92 | for key, value in sorted(vars(args).items()):
93 | log.info(str(key) + ': ' + str(value))
94 |
95 |
96 | model = models.LWADNet.AnyNet(args)
97 |
98 |
99 | model = nn.DataParallel(model).cuda()
100 | optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999))
101 | log.info('Number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()])))
102 |
103 | if args.pretrained:
104 | if os.path.isfile(args.pretrained):
105 | checkpoint = torch.load(args.pretrained)
106 | model.load_state_dict(checkpoint['state_dict'], strict=False)
107 | log.info("=> loaded pretrained model '{}'"
108 | .format(args.pretrained))
109 | else:
110 | log.info("=> no pretrained model found at '{}'".format(args.pretrained))
111 | log.info("=> Will start from scratch.")
112 | args.start_epoch = 0
113 |
114 | cudnn.benchmark = True
115 |
116 | start_full_time = time.time()
117 | loss_file = open(args.save_path + '/self_supervise' + '.txt', 'w')
118 |
119 | for epoch in range(args.start_epoch, args.epochs):
120 | log.info('This is {}-th epoch'.format(epoch))
121 |
122 | D1s= train(TrainImgLoader, model, optimizer, log, epoch)
123 | loss_file.write('{:.4f}\n'.format(D1s))
124 |
125 |
126 | loss_file.close()
127 |
128 | log.info('full training time = {:.2f} Hours'.format((time.time() - start_full_time) / 3600))
129 |
130 |
131 | def train(dataloader, model, optimizer, log, epoch=0):
132 |
133 |
134 |
135 | stages = 3 + args.with_cspn
136 | losses = [AverageMeter() for _ in range(stages)]
137 | length_loader = len(dataloader)
138 | D1s = [AverageMeter() for _ in range(2)]
139 |
140 |
141 | model.train()
142 |
143 | #loss_file = open(args.save_path + '/self_adaptive_loss' + '.txt', 'w')
144 |
145 | for batch_idx, (imgL, imgR, disp_L) in enumerate(dataloader):
146 | imgL = imgL.float().cuda()
147 | imgR = imgR.float().cuda()
148 | disp_L = disp_L.float().cuda()
149 | #print(' disp_L size:', disp_L)
150 |
151 |
152 |
153 | optimizer.zero_grad()
154 | mask = disp_L > 0
155 | mask.detach_()
156 |
157 | #outputs = model(imgL, imgR)
158 | pred, mono_loss = model(imgL, imgR)
159 |
160 | for x in range(len(pred)):
161 | output = torch.squeeze(pred[x], 1)
162 | D1s[x].update(error_estimating(output, disp_L).item())
163 |
164 | # loss_file.write('{:.4f}\n'.format(D1s[1].val))
165 | # loss_file.close()
166 |
167 | # print('len(outputs)', len(outputs))
168 |         pred = list(pred)
169 | num_out = len(pred)
170 | #print('num_out:', num_out)
171 |
172 |
173 | outputs = [torch.squeeze(output, 1) for output in pred]
174 |
175 | output_save = outputs[1].squeeze(0)
176 | #print('output_save:', output_save.shape)
177 |
178 | #io.imsave(args.save_path + '/epoch {}.png'.format(epoch), (output_save.cpu().data.numpy() ))
179 |
180 | plt.imshow(output_save.detach().cpu().numpy())
181 | plt.axis('off')
182 |
183 | #plt.savefig(args.save_path+'/image'+ '/epoch {} D1 {:.4f}.png'.format(epoch, D1s[1].val))
184 | plt.savefig(args.save_path + '/image' + '/epoch {} D1 {:.4f}.png'.format(epoch, D1s[1].val), bbox_inches = 'tight', dpi= 300, pad_inches = 0)
185 |
186 |
187 | loss = [args.loss_weights[x] * F.smooth_l1_loss(outputs[x][mask], disp_L[mask], size_average=True)
188 | for x in range(num_out)]
189 |
190 | #if args.adaptation_type == "no_supervise":
191 |
192 | #sum(mono_loss).backward()
193 | sum(loss).backward()
194 | #
195 | optimizer.step()
196 |
197 | for idx in range(num_out):
198 | losses[idx].update(loss[idx].item())
199 |
200 | if 1:
201 | info_str = ['Stage {} = {:.2f}({:.2f})'.format(x, losses[x].val, losses[x].avg) for x in range(num_out)]
202 | info_str = '\t'.join(info_str)
203 |
204 | log.info('Epoch{} [{}/{}] {}'.format(
205 | epoch, batch_idx, length_loader, info_str))
206 |
207 | info_str = '\t'.join(
208 | ['Stage {} = {:.4f}({:.4f})'.format(x, D1s[x].val, D1s[x].avg) for x in range(num_out)])
209 |
210 | log.info('[{}/{}] {}'.format(
211 | batch_idx, length_loader, info_str))
212 |
213 | return D1s[1].val
214 |
215 |
216 | # info_str = '\t'.join(['Stage {} = {:.2f}'.format(x, losses[x].avg) for x in range(stages)])
217 | # info_str = '\t'.join(['Stage {} = {:.2f}'.format(x, losses[x].avg) for x in range(2)])
218 | # log.info('Average train loss = ' + info_str)
219 |
220 |
221 | def test(dataloader, model, log):
222 |
223 | stages = 3 + args.with_cspn
224 | D1s = [AverageMeter() for _ in range(stages)]
225 | length_loader = len(dataloader)
226 |
227 | model.eval()
228 |
229 | for batch_idx, (imgL, imgR, disp_L) in enumerate(dataloader):
230 | imgL = imgL.float().cuda()
231 | imgR = imgR.float().cuda()
232 | disp_L = disp_L.float().cuda()
233 | # print('test imgR size:', imgR.shape)
234 |
235 | # imgL = F.pad(imgL, [3, 3, 1, 0])
236 | # imgR = F.pad(imgR, [3, 3, 1, 0])
237 | # disp_L = F.pad(disp_L, [3, 3, 1, 0])
238 | #print('imgR size:', imgR.shape)
239 |
240 | with torch.no_grad():
241 | outputs, mono_loss = model(imgL, imgR, train = 0)
242 |
243 |
244 | # for x in range(stages):
245 | if args.with_cspn:
246 | # if epoch >= args.start_epoch_for_spn:
247 | # num_out = len(outputs)
248 | # else:
249 | # num_out = len(outputs) - 1
250 | num_out = len(outputs)
251 |
252 | else:
253 | num_out = len(outputs)
254 |
255 | for x in range(num_out):
256 | output = torch.squeeze(outputs[x], 1)
257 |
258 | # print('output size:', output.shape)
259 | # print('disp_L size:', disp_L.shape)
260 | D1s[x].update(error_estimating(output, disp_L).item())
261 |
262 |
263 | info_str = '\t'.join(['Stage {} = {:.4f}({:.4f})'.format(x, D1s[x].val, D1s[x].avg) for x in range(num_out)])
264 |
265 |
266 | log.info('[{}/{}] {}'.format(
267 | batch_idx, length_loader, info_str))
268 |
269 |
270 | info_str = ', '.join(['Stage {}={:.4f}'.format(x, D1s[x].avg) for x in range(num_out)])
271 |
272 | log.info('Average test 3-Pixel Error = ' + info_str)
273 |
274 |
275 | def error_estimating(disp, ground_truth, maxdisp=192):
276 | gt = ground_truth
277 |
278 |
279 | # gt = gt[:, 0:368, 50:1200]
280 | # disp = disp[:, 0:368, 50:1200]
281 | # print('gt shape:', gt.shape)
282 |
283 | #mask = gt[:, 0:368, 50:1232]> 0
284 |
285 | mask = gt > 0
286 | mask = mask * (gt < maxdisp)
287 |
288 | errmap = torch.abs(disp - gt)
289 | err3 = ((errmap[mask] > 3.) & (errmap[mask] / gt[mask] > 0.05)).sum()
290 | return err3.float() / mask.sum().float()
291 |
292 |
293 | def adjust_learning_rate(optimizer, epoch):
294 | if epoch <= 1000:
295 | lr = args.lr
296 | elif epoch <= 1500:
297 | lr = args.lr * 0.1
298 | else:
299 | lr = args.lr * 0.01
300 | for param_group in optimizer.param_groups:
301 | param_group['lr'] = lr
302 |
303 | class AverageMeter(object):
304 | """Computes and stores the average and current value"""
305 |
306 | def __init__(self):
307 | self.reset()
308 |
309 | def reset(self):
310 | self.val = 0
311 | self.avg = 0
312 | self.sum = 0
313 | self.count = 0
314 |
315 | def update(self, val, n=1):
316 | self.val = val
317 | self.sum += val * n
318 | self.count += n
319 | self.avg = self.sum / self.count
320 |
321 |
322 | def post_process_disparity(disp):
323 | _, h, w = disp[0].shape
324 | #print('disp[0].shape:', disp[0].shape) # torch.Size([1, 368, 1232])
325 |
326 | l_disp = disp[0].cpu().numpy()
327 | #r_disp = np.fliplr(disp[1].cpu())
328 | r_disp = disp[1].cpu().numpy()
329 |
330 | #m_disp = 0.5 * (l_disp + r_disp)
331 |
332 | l, _ = np.meshgrid(np.linspace(0, 1, w), np.linspace(0, 1, h))
333 | l_mask = 1.0 - np.clip(20 * (l - 0.05), 0, 1)
334 | #r_mask =np.fliplr(l_mask)
335 | # return r_mask * l_disp + l_mask * r_disp + (1.0 - l_mask - r_mask) * m_disp
336 | return l_mask * r_disp + (1.0 - l_mask ) * l_disp
337 |     # originally this was just l_disp, dropped directly
338 |
339 |
340 |
341 |
342 |
343 | if __name__ == '__main__':
344 | main()
345 |
346 |
347 |
348 |
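349 | # A minimal, hypothetical self-check of the blending ramp in post_process_disparity (the
350 | # helper is not called anywhere): the mask equals 1 over roughly the leftmost 5% of columns,
351 | # decays linearly to 0 by about 10%, and is 0 elsewhere, so disp[1] is used only near the
352 | # left image border and disp[0] everywhere else.
353 | def _blend_mask_sketch():
354 |     l, _ = np.meshgrid(np.linspace(0, 1, 21), np.linspace(0, 1, 2))
355 |     l_mask = 1.0 - np.clip(20 * (l - 0.05), 0, 1)
356 |     assert l_mask[0, 0] == 1.0 and l_mask[0, 1] == 1.0 and l_mask[0, -1] == 0.0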
--------------------------------------------------------------------------------