├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── docs
│   ├── CVPR_2019_poster.jpg
│   ├── CVPR_2019_poster.pdf
│   ├── cvpr2018-pipeline.png
│   ├── index.html
│   ├── main_result.png
│   ├── pdf_thumbnail.jpg
│   └── table1_caption.png
├── figures
│   ├── cvpr2018-pipeline.png
│   └── result.png
├── preprocessing
│   ├── generate_disp.py
│   ├── generate_lidar.py
│   ├── kitti_process_RANSAC.py
│   └── kitti_util.py
├── psmnet
│   ├── README.md
│   ├── dataloader
│   │   ├── KITTILoader.py
│   │   ├── KITTILoader3D.py
│   │   ├── KITTILoader_dataset3d.py
│   │   ├── KITTI_submission_loader.py
│   │   ├── KITTI_submission_loader2012.py
│   │   ├── KITTIloader2012.py
│   │   ├── KITTIloader2015.py
│   │   ├── SecenFlowLoader.py
│   │   ├── __init__.py
│   │   ├── listflowfile.py
│   │   ├── preprocess.py
│   │   └── readpfm.py
│   ├── finetune_3d.py
│   ├── logger.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── basic.py
│   │   ├── stackhourglass.py
│   │   └── submodule.py
│   ├── submission.py
│   └── utils
│       ├── __init__.py
│       ├── preprocess.py
│       └── readpfm.py
└── visualization
    ├── 000012.bin
    ├── pyntcloud.ipynb
    └── pyntcloud.png
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "avod"]
2 | path = avod
3 | url = git@github.com:mileyan/avod.git
4 |
5 | [submodule "frustum-pointnets"]
6 | path = frustum-pointnets
7 | url = git@github.com:charlesq34/frustum-pointnets.git
8 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Yan (Eric) Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Pseudo-LiDAR from Visual Depth Estimation: Bridging the Gap in 3D Object Detection for Autonomous Driving
2 | This paper has been accepted by the Conference on Computer Vision and Pattern Recognition ([CVPR](http://cvpr2019.thecvf.com/)) 2019.
3 |
4 | [
5 | Pseudo-LiDAR from Visual Depth Estimation: Bridging the Gap in 3D Object Detection for Autonomous Driving](https://arxiv.org/abs/1812.07179)
6 |
7 | by [Yan Wang](https://www.cs.cornell.edu/~yanwang/), [Wei-Lun Chao](http://www-scf.usc.edu/~weilunc/), [Divyansh Garg](http://divyanshgarg.com/), [Bharath Hariharan](http://home.bharathh.info/), [Mark Campbell](https://campbell.mae.cornell.edu/) and [Kilian Q. Weinberger](http://kilian.cs.cornell.edu/)
8 |
9 | 
10 | ### Citation
11 | ```
12 | @inproceedings{wang2019pseudo,
13 | title={Pseudo-LiDAR from Visual Depth Estimation: Bridging the Gap in 3D Object Detection for Autonomous Driving},
14 | author={Wang, Yan and Chao, Wei-Lun and Garg, Divyansh and Hariharan, Bharath and Campbell, Mark and Weinberger, Kilian},
15 | booktitle={CVPR},
16 | year={2019}
17 | }
18 | ```
19 | ## Update
20 | * 2nd July 2020: Added a Jupyter notebook to visualize the pseudo-LiDAR point clouds. It is in the `./visualization` folder.
21 | * 29th July 2019: `submission.py` now saves the disparity as a numpy (`.npy`) file instead of a png file, and `generate_lidar.py` has been fixed accordingly.
22 | * I have modified the official AVOD slightly so that you can directly train and test pseudo-LiDAR with AVOD. Please check the code at https://github.com/mileyan/avod_pl.
23 |
24 | ## Contents
25 |
26 | - [Introduction](#introduction)
27 | - [Usage](#usage)
28 | - [Results](#results)
29 | - [Contact](#contact)
30 |
31 | ## Introduction
32 | 3D object detection is an essential task in autonomous driving. Recent techniques excel with highly accurate detection rates, provided the 3D input data is obtained from precise but expensive LiDAR technology. Approaches based on cheaper monocular or stereo imagery data have, until now, resulted in drastically lower accuracies --- a gap that is commonly attributed to poor image-based depth estimation. However, in this paper we argue that data representation (rather than its quality) accounts for the majority of the difference. Taking the inner workings of convolutional neural networks into consideration, we propose to convert image-based depth maps to pseudo-LiDAR representations --- essentially mimicking LiDAR signal. With this representation we can apply different existing LiDAR-based detection algorithms. On the popular KITTI benchmark, our approach achieves impressive improvements over the existing state-of-the-art in image-based performance --- raising the detection accuracy of objects within 30m range from the previous state-of-the-art of 22% to an unprecedented 74%. At the time of submission our algorithm holds the highest entry on the KITTI 3D object detection leaderboard for stereo image based approaches.
33 |
34 | ## Usage
35 |
36 | ### 1. Overview
37 |
38 | We provide the guidance and codes to train stereo depth estimator and 3D object detector using the [KITTI object detection benchmark](http://www.cvlibs.net/datasets/kitti/eval_object.php?obj_benchmark=3d). We also provide our pre-trained models.
39 |
40 | ### 2. Stereo depth estimation models
41 | We provide our pretrained [PSMNet](http://openaccess.thecvf.com/content_cvpr_2018/papers/Chang_Pyramid_Stereo_Matching_CVPR_2018_paper.pdf) model, trained on the Scene Flow dataset and the 3,712 training images of the KITTI detection benchmark.
42 | - [Pretrained PSMNet](https://drive.google.com/file/d/1sWjsIO9Fuy92wT3gLkHF3PA7SP8QZBzu/view?usp=sharing)
43 |
44 | We also directly provide the pseudo-LiDAR point clouds and the ground planes of training and testing images estimated by this pre-trained model.
45 | - [training/pseudo-lidar_velodyne](https://drive.google.com/file/d/10txZOtKk_aY3B7AhHjJPMCiRf5pP62nV/view?usp=sharing)
46 | - [testing/pseudo-lidar_velodyne](https://drive.google.com/file/d/1XRAWYpMJeaVVXNN442xDgXnAa3pLBUvv/view?usp=sharing)
47 | - [training/pseudo-lidar_planes](https://drive.google.com/file/d/1NBN85o9Jl7FjV5HwldmBv_9T4LeoNiwV/view?usp=sharing)
48 | - [testing/pseudo-lidar_planes](https://drive.google.com/file/d/1G5_5VHbygssrKOzz1zEirNlKjVnMc5tz/view?usp=sharing)
49 |
50 | We also provide code to train your own stereo depth estimator and prepare the point clouds and ground planes. **If you want to use our pseudo-LiDAR data for 3D object detection, you may skip the following contents and directly move on to object detection models.**
51 |
52 | #### 2.1 Dependencies
53 | - Python 3.5+
54 | - numpy, scikit-learn, scipy
55 | - KITTI 3D object detection dataset
56 |
57 | #### 2.2 Download the dataset
58 | You need to download the KITTI dataset from [here](http://www.cvlibs.net/datasets/kitti/eval_object.php?obj_benchmark=3d), including left and right color images, Velodyne point clouds, camera calibration matrices, and training labels. You also need to download the image set files from [here](https://github.com/charlesq34/frustum-pointnets/tree/master/kitti/image_sets). Then you need to organize the data in the following way.
59 | ```angular2html
60 | KITTI/object/
61 |
62 | train.txt
63 | val.txt
64 | test.txt
65 |
66 | training/
67 | calib/
68 | image_2/ #left image
69 | image_3/ #right image
70 | label_2/
71 | velodyne/
72 |
73 | testing/
74 | calib/
75 | image_2/
76 | image_3/
77 | velodyne/
78 | ```
79 | The Velodyne point clouds (from LiDAR) are used **ONLY** as ground truth to train a stereo depth estimator (e.g., PSMNet).
80 | #### 2.3 Generate ground-truth image disparities
81 | Use the script `./preprocessing/generate_disp.py` to process all Velodyne files listed in `train.txt`. This produces our **training ground truth**. Alternatively, you can directly download it from [disparity](https://drive.google.com/file/d/1JqtPdYnajNhDNxucuQYmD-79rl7MIXoZ/view?usp=sharing). Name the folder `disparity` and put it inside the `training` folder.
82 | ```angular2html
83 | python generate_disp.py --data_path ./KITTI/object/training/ --split_file ./KITTI/object/train.txt
84 | ```
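
Under the hood, `generate_disp.py` projects each Velodyne point into the left image and converts its rectified depth `z` into a disparity via `disparity = f_u * baseline / z`, with the KITTI stereo baseline of 0.54 m hard-coded in the script. Below is a minimal sketch of that relation using the helpers in `./preprocessing/kitti_util.py`; frame `000000` is only an illustrative example, and the actual script additionally rasterizes the values into a dense per-pixel map:

```python
import numpy as np
import kitti_util  # from ./preprocessing

# illustrative frame id; any id listed in train.txt works the same way
calib = kitti_util.Calibration('./KITTI/object/training/calib/000000.txt')
velo = np.fromfile('./KITTI/object/training/velodyne/000000.bin',
                   dtype=np.float32).reshape(-1, 4)[:, :3]

depth = calib.project_velo_to_rect(velo)[:, 2]  # z in the rectified camera frame
disparity = calib.f_u * 0.54 / depth            # per-point ground-truth disparity
```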
85 |
86 | #### 2.4. Train the stereo model
87 | You can train any stereo disparity model you want. Here we give an example of training PSMNet. The modified code is saved in the subfolder `psmnet`. Make sure you follow the `README` inside that folder to install the correct Python version and libraries. I strongly suggest using `conda env` to manage the Python environments, since we will use different Python versions. Download the PSMNet model pretrained on the Scene Flow dataset from [here](https://drive.google.com/file/d/1D-OcFbrQXNl3iSOeBnMBGd87pNXp0RT1/view?usp=sharing).
88 |
89 | ```python2html
90 | # train psmnet with 4 TITAN X GPUs.
91 | python ./psmnet/finetune_3d.py --maxdisp 192 \
92 | --model stackhourglass \
93 | --datapath ./KITTI/object/training/ \
94 | --split_file ./KITTI/object/train.txt \
95 | --epochs 300 \
96 | --lr_scale 50 \
97 | --loadmodel ./pretrained_sceneflow.tar \
98 | --savemodel ./psmnet/kitti_3d/ --btrain 12
99 | ```
100 |
101 | #### 2.5 Predict the point clouds
102 | ##### Predict the disparities.
103 | ```angular2html
104 | # training
105 | python ./psmnet/submission.py \
106 | --loadmodel ./psmnet/kitti_3d/finetune_300.tar \
107 | --datapath ./KITTI/object/training/ \
108 | --save_path ./KITTI/object/training/predict_disparity
109 | # testing
110 | python ./psmnet/submission.py \
111 | --loadmodel ./psmnet/kitti_3d/finetune_300.tar \
112 | --datapath ./KITTI/object/testing/ \
113 | --save_path ./KITTI/object/testing/predict_disparity
114 | ```
115 | ##### Convert the disparities to point clouds.
116 | ```angular2html
117 | # training
118 | python ./preprocessing/generate_lidar.py \
119 | --calib_dir ./KITTI/object/training/calib/ \
120 | --save_dir ./KITTI/object/training/pseudo-lidar_velodyne/ \
121 | --disparity_dir ./KITTI/object/training/predict_disparity \
122 | --max_high 1
123 | # testing
124 | python ./preprocessing/generate_lidar.py \
125 | --calib_dir ./KITTI/object/testing/calib/ \
126 | --save_dir ./KITTI/object/testing/pseudo-lidar_velodyne/ \
127 | --disparity_dir ./KITTI/object/testing/predict_disparity \
128 | --max_high 1
129 | ```
130 | If you want to generate point clouds from depth maps (e.g., produced by DORN), add `--is_depth` to the command.
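
For reference, `./preprocessing/generate_lidar.py` performs three steps per frame: it converts disparity back to depth (this step is skipped with `--is_depth`), back-projects every pixel with the camera calibration, and drops points behind the sensor or above `--max_high` meters. A minimal sketch of the same pipeline, where frame `000000` is illustrative and a simple clip stands in for the script's handling of zero-disparity pixels:

```python
import numpy as np
import kitti_util  # from ./preprocessing

calib = kitti_util.Calibration('./KITTI/object/training/calib/000000.txt')
disp = np.load('./KITTI/object/training/predict_disparity/000000.npy')  # H x W disparity map

depth = calib.f_u * 0.54 / np.clip(disp, 1e-3, None)     # disparity -> depth (0.54 m baseline)
v, u = np.indices(depth.shape)                           # pixel coordinates (row, col)
uvd = np.stack([u, v, depth], axis=-1).reshape(-1, 3)    # one (u, v, depth) triple per pixel
cloud = calib.project_image_to_velo(uvd)                 # N x 3 points in Velodyne coordinates
cloud = cloud[(cloud[:, 0] >= 0) & (cloud[:, 2] < 1.0)]  # in front of the sensor, below max_high
```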
131 |
132 | #### 2.6 Generate ground plane
133 | If you want to train an [AVOD](https://github.com/kujason/avod) model for 3D object detection, you need to generate ground planes from the pseudo-LiDAR point clouds.
134 | ```angular2html
135 | #training
136 | python ./preprocessing/kitti_process_RANSAC.py \
137 | --calib_dir ./KITTI/object/training/calib/ \
138 | --lidar_dir ./KITTI/object/training/pseudo-lidar_velodyne/ \
139 | --planes_dir ./KITTI/object/training/pseudo-lidar_planes/
140 | #testing
141 | python ./preprocessing/kitti_process_RANSAC.py \
142 | --calib_dir ./KITTI/object/testing/calib/ \
143 | --lidar_dir ./KITTI/object/testing/pseudo-lidar_velodyne/ \
144 | --planes_dir ./KITTI/object/testing/pseudo-lidar_planes/
145 | ```
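
Each file written to `pseudo-lidar_planes` follows the plane format expected by AVOD, as produced by `kitti_process_RANSAC.py`: a three-line header followed by the fitted ground-plane parameters `w_x w_y w_z h`. The numbers below are purely illustrative:

```
# Plane
Width 4
Height 1
2.1e-03 -9.99e-01 8.5e-03 1.65e+00
```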
146 | ### 3. Object Detection models
147 | #### AVOD model
148 | Download the code from [https://github.com/kujason/avod](https://github.com/kujason/avod) and install the Python dependencies.
149 |
150 | Follow their README to prepare the data and then replace (1) files in `velodyne` with those in `pseudo-lidar_velodyne` and (2) files in `planes` with those in `pseudo-lidar_planes`. Note that you should still keep the folder names as `velodyne` and `planes`.
151 |
152 | Follow their README to train the `pyramid_cars_with_aug_example` model. You can also download our pretrained model and evaluate it directly. However, if you want to submit your results to the leaderboard, you need to train it on `trainval.txt`.
153 |
154 | - [pretrained AVOD](https://drive.google.com/file/d/1wuMykUDx8tcCfxpqnprmzrgUyheQV42F/view?usp=sharing) (trained only on train.txt)
155 |
156 |
157 |
158 | #### Frustum-PointNets model
159 | Download the code from [https://github.com/charlesq34/frustum-pointnets](https://github.com/charlesq34/frustum-pointnets) and install the Python dependencies.
160 |
161 | Follow their README to prepare the data and then replace files in `velodyne` with those in `pseudo-lidar_velodyne`. Note that you should still keep the folder name as `velodyne`.
162 |
163 | Follow their README to train the v1 model. You can also download our pretrained model and evaluate it directly.
164 |
165 | - [pretrained Frustum_V1](https://drive.google.com/file/d/1qhCxw6uHqQ4SAkxIuBi-QCKqLmTGiNhP/view?usp=sharing) (trained only on train.txt)
166 |
167 | ## Results
168 | The main results of our pseudo-LiDAR method on the KITTI validation set.
169 | 
170 |
171 | You can download the AVOD validation results from [HERE](https://drive.google.com/file/d/13nOhBCmj8rzjMHDEw3syROuqHsoxWIKJ/view?usp=sharing).
172 |
173 |
174 | ## Contact
175 | If you have any questions, please feel free to email us.
176 |
177 | Yan Wang (yw763@cornell.edu), Harry Chao (weilunchao760414@gmail.com), Div Garg (dg595@cornell.edu)
178 |
--------------------------------------------------------------------------------
/docs/CVPR_2019_poster.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mileyan/pseudo_lidar/032c7a0d73c3fdf84e934af3f57f8eb489a52906/docs/CVPR_2019_poster.jpg
--------------------------------------------------------------------------------
/docs/CVPR_2019_poster.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mileyan/pseudo_lidar/032c7a0d73c3fdf84e934af3f57f8eb489a52906/docs/CVPR_2019_poster.pdf
--------------------------------------------------------------------------------
/docs/cvpr2018-pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mileyan/pseudo_lidar/032c7a0d73c3fdf84e934af3f57f8eb489a52906/docs/cvpr2018-pipeline.png
--------------------------------------------------------------------------------
/docs/index.html:
--------------------------------------------------------------------------------
Project page for the paper (the HTML markup was not preserved in this dump; only the recoverable text is kept below).

Pseudo-LiDAR from Visual Depth Estimation: Bridging the Gap in 3D Object Detection for Autonomous Driving
Yan Wang, Wei-Lun Chao, Divyansh Garg, Bharath Hariharan, Mark Campbell, Kilian Q. Weinberger
Cornell University, Ithaca, NY

Page sections: Video, Abstract (identical to the abstract in README.md), Architecture, Experiment Results, Paper, Poster, Citation.

Citation:
@article{wang2018pseudo,
  title={Pseudo-LiDAR from Visual Depth Estimation: Bridging the Gap in 3D Object Detection for Autonomous Driving},
  author={Wang, Yan and Chao, Wei-Lun and Garg, Divyansh and Hariharan, Bharath and Campbell, Mark and Weinberger, Kilian Q.},
  journal={arXiv preprint arXiv:1812.07179},
  year={2018}
}
--------------------------------------------------------------------------------
/docs/main_result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mileyan/pseudo_lidar/032c7a0d73c3fdf84e934af3f57f8eb489a52906/docs/main_result.png
--------------------------------------------------------------------------------
/docs/pdf_thumbnail.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mileyan/pseudo_lidar/032c7a0d73c3fdf84e934af3f57f8eb489a52906/docs/pdf_thumbnail.jpg
--------------------------------------------------------------------------------
/docs/table1_caption.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mileyan/pseudo_lidar/032c7a0d73c3fdf84e934af3f57f8eb489a52906/docs/table1_caption.png
--------------------------------------------------------------------------------
/figures/cvpr2018-pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mileyan/pseudo_lidar/032c7a0d73c3fdf84e934af3f57f8eb489a52906/figures/cvpr2018-pipeline.png
--------------------------------------------------------------------------------
/figures/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mileyan/pseudo_lidar/032c7a0d73c3fdf84e934af3f57f8eb489a52906/figures/result.png
--------------------------------------------------------------------------------
/preprocessing/generate_disp.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import numpy as np
5 | import scipy.misc as ssc  # scipy.misc.imread requires SciPy < 1.2; it was removed in later releases
6 |
7 | import kitti_util
8 |
9 |
10 | def generate_dispariy_from_velo(pc_velo, height, width, calib):
11 | pts_2d = calib.project_velo_to_image(pc_velo)
12 | fov_inds = (pts_2d[:, 0] < width - 1) & (pts_2d[:, 0] >= 0) & \
13 | (pts_2d[:, 1] < height - 1) & (pts_2d[:, 1] >= 0)
14 | fov_inds = fov_inds & (pc_velo[:, 0] > 2)
15 | imgfov_pc_velo = pc_velo[fov_inds, :]
16 | imgfov_pts_2d = pts_2d[fov_inds, :]
17 | imgfov_pc_rect = calib.project_velo_to_rect(imgfov_pc_velo)
18 | depth_map = np.zeros((height, width)) - 1
19 | imgfov_pts_2d = np.round(imgfov_pts_2d).astype(int)
20 | for i in range(imgfov_pts_2d.shape[0]):
21 | depth = imgfov_pc_rect[i, 2]
22 | depth_map[int(imgfov_pts_2d[i, 1]), int(imgfov_pts_2d[i, 0])] = depth
23 | baseline = 0.54
24 |
25 | disp_map = (calib.f_u * baseline) / depth_map
26 | return disp_map
27 |
28 |
29 | if __name__ == '__main__':
30 | parser = argparse.ArgumentParser(description='Generate Disparity')
31 | parser.add_argument('--data_path', type=str, default='~/Kitti/object/training/')
32 | parser.add_argument('--split_file', type=str, default='~/Kitti/object/train.txt')
33 | args = parser.parse_args()
34 |
35 | assert os.path.isdir(args.data_path)
36 | lidar_dir = args.data_path + '/velodyne/'
37 | calib_dir = args.data_path + '/calib/'
38 | image_dir = args.data_path + '/image_2/'
39 | disparity_dir = args.data_path + '/disparity/'
40 |
41 | assert os.path.isdir(lidar_dir)
42 | assert os.path.isdir(calib_dir)
43 | assert os.path.isdir(image_dir)
44 |
45 | if not os.path.isdir(disparity_dir):
46 | os.makedirs(disparity_dir)
47 |
48 | lidar_files = [x for x in os.listdir(lidar_dir) if x[-3:] == 'bin']
49 | lidar_files = sorted(lidar_files)
50 |
51 | assert os.path.isfile(args.split_file)
52 | with open(args.split_file, 'r') as f:
53 | file_names = [x.strip() for x in f.readlines()]
54 |
55 | for fn in lidar_files:
56 | predix = fn[:-4]
57 | if predix not in file_names:
58 | continue
59 | calib_file = '{}/{}.txt'.format(calib_dir, predix)
60 | calib = kitti_util.Calibration(calib_file)
61 | # load point cloud
62 | lidar = np.fromfile(lidar_dir + '/' + fn, dtype=np.float32).reshape((-1, 4))[:, :3]
63 | image_file = '{}/{}.png'.format(image_dir, predix)
64 | image = ssc.imread(image_file)
65 | height, width = image.shape[:2]
66 | disp = generate_dispariy_from_velo(lidar, height, width, calib)
67 | np.save(disparity_dir + '/' + predix, disp)
68 | print('Finish Disparity {}'.format(predix))
69 |
--------------------------------------------------------------------------------
/preprocessing/generate_lidar.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import numpy as np
5 | import scipy.misc as ssc  # scipy.misc.imread requires SciPy < 1.2; it was removed in later releases
6 |
7 | import kitti_util
8 |
9 |
10 | def project_disp_to_points(calib, disp, max_high):
11 | disp[disp < 0] = 0
12 | baseline = 0.54
13 | mask = disp > 0
14 | depth = calib.f_u * baseline / (disp + 1. - mask)  # the "+ 1. - mask" term keeps zero-disparity pixels from dividing by zero; they are filtered out below
15 | rows, cols = depth.shape
16 | c, r = np.meshgrid(np.arange(cols), np.arange(rows))
17 | points = np.stack([c, r, depth])
18 | points = points.reshape((3, -1))
19 | points = points.T
20 | points = points[mask.reshape(-1)]
21 | cloud = calib.project_image_to_velo(points)
22 | valid = (cloud[:, 0] >= 0) & (cloud[:, 2] < max_high)
23 | return cloud[valid]
24 |
25 | def project_depth_to_points(calib, depth, max_high):
26 | rows, cols = depth.shape
27 | c, r = np.meshgrid(np.arange(cols), np.arange(rows))
28 | points = np.stack([c, r, depth])
29 | points = points.reshape((3, -1))
30 | points = points.T
31 | cloud = calib.project_image_to_velo(points)
32 | valid = (cloud[:, 0] >= 0) & (cloud[:, 2] < max_high)
33 | return cloud[valid]
34 |
35 | if __name__ == '__main__':
36 | parser = argparse.ArgumentParser(description='Generate Lidar')
37 | parser.add_argument('--calib_dir', type=str,
38 | default='~/Kitti/object/training/calib')
39 | parser.add_argument('--disparity_dir', type=str,
40 | default='~/Kitti/object/training/predicted_disparity')
41 | parser.add_argument('--save_dir', type=str,
42 | default='~/Kitti/object/training/predicted_velodyne')
43 | parser.add_argument('--max_high', type=int, default=1)
44 | parser.add_argument('--is_depth', action='store_true')
45 |
46 | args = parser.parse_args()
47 |
48 | assert os.path.isdir(args.disparity_dir)
49 | assert os.path.isdir(args.calib_dir)
50 |
51 | if not os.path.isdir(args.save_dir):
52 | os.makedirs(args.save_dir)
53 |
54 | disps = [x for x in os.listdir(args.disparity_dir) if x[-3:] == 'png' or x[-3:] == 'npy']
55 | disps = sorted(disps)
56 |
57 | for fn in disps:
58 | predix = fn[:-4]
59 | calib_file = '{}/{}.txt'.format(args.calib_dir, predix)
60 | calib = kitti_util.Calibration(calib_file)
61 | # disp_map = ssc.imread(args.disparity_dir + '/' + fn) / 256.
62 | if fn[-3:] == 'png':
63 | disp_map = ssc.imread(args.disparity_dir + '/' + fn)
64 | elif fn[-3:] == 'npy':
65 | disp_map = np.load(args.disparity_dir + '/' + fn)
66 | else:
67 | assert False
68 | if not args.is_depth:
69 | disp_map = (disp_map*256).astype(np.uint16)/256.
70 | lidar = project_disp_to_points(calib, disp_map, args.max_high)
71 | else:
72 | disp_map = (disp_map).astype(np.float32)/256.
73 | lidar = project_depth_to_points(calib, disp_map, args.max_high)
74 | # pad 1 in the intensity dimension
75 | lidar = np.concatenate([lidar, np.ones((lidar.shape[0], 1))], 1)
76 | lidar = lidar.astype(np.float32)
77 | lidar.tofile('{}/{}.bin'.format(args.save_dir, predix))
78 | print('Finish Depth {}'.format(predix))
79 |
--------------------------------------------------------------------------------
/preprocessing/kitti_process_RANSAC.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import numpy as np
5 | from sklearn.linear_model import RANSACRegressor
6 |
7 | import kitti_util as utils
8 |
9 |
10 | def extract_ransac(calib_dir, lidar_dir, planes_dir):
11 | data_idx_list = [x[:-4] for x in os.listdir(lidar_dir) if x[-4:] == '.bin']
12 |
13 | if not os.path.isdir(planes_dir):
14 | os.makedirs(planes_dir)
15 |
16 | for data_idx in data_idx_list:
17 |
18 | print('------------- ', data_idx)
19 | calib = calib_dir + '/' + data_idx + '.txt'
20 | calib = utils.Calibration(calib)
21 | pc_velo = lidar_dir + '/' + data_idx + '.bin'
22 | pc_velo = np.fromfile(pc_velo, dtype=np.float32).reshape(-1, 4)
23 | pc_rect = calib.project_velo_to_rect(pc_velo[:, :3])
24 | valid_loc = (pc_rect[:, 1] > 1.5) & \
25 | (pc_rect[:, 1] < 1.86) & \
26 | (pc_rect[:, 2] > 0) & \
27 | (pc_rect[:, 2] < 40) & \
28 | (pc_rect[:, 0] > -15) & \
29 | (pc_rect[:, 0] < 15)
30 | pc_rect = pc_rect[valid_loc]
31 | if len(pc_rect) < 1:
32 | w = [0, -1, 0]
33 | h = 1.65
34 | else:
35 | reg = RANSACRegressor().fit(pc_rect[:, [0, 2]], pc_rect[:, 1])
36 | w = np.zeros(3)
37 | w[0] = reg.estimator_.coef_[0]
38 | w[2] = reg.estimator_.coef_[1]
39 | w[1] = -1.0
40 | h = reg.estimator_.intercept_
41 | w = w / np.linalg.norm(w)
42 | print(w)
43 | print(h)
44 |
45 | lines = ['# Plane', 'Width 4', 'Height 1']
46 |
47 | plane_file = os.path.join(planes_dir, data_idx + '.txt')
48 | result_lines = lines[:3]
49 | result_lines.append("{:e} {:e} {:e} {:e}".format(w[0], w[1], w[2], h))
50 | result_str = '\n'.join(result_lines)
51 | with open(plane_file, 'w') as f:
52 | f.write(result_str)
53 |
54 |
55 | if __name__ == '__main__':
56 | parser = argparse.ArgumentParser()
57 | parser.add_argument('--calib_dir', default='KITTI/object/training/calib')
58 | parser.add_argument('--lidar_dir', default='KITTI/object/training/velodyne')
59 | parser.add_argument('--planes_dir', default='KITTI/object/training/velodyne_planes')
60 | args = parser.parse_args()
61 |
62 | extract_ransac(args.calib_dir, args.lidar_dir, args.planes_dir)
63 |
--------------------------------------------------------------------------------
/preprocessing/kitti_util.py:
--------------------------------------------------------------------------------
1 | """ Helper methods for loading and parsing KITTI data.
2 |
3 | Author: Charles R. Qi
4 | Date: September 2017
5 | """
6 | from __future__ import print_function
7 |
8 | import numpy as np
9 |
10 |
11 | class Calibration(object):
12 | ''' Calibration matrices and utils
13 | 3d XYZ in .txt are in rect camera coord.
14 | 2d box xy are in image2 coord
15 | Points in .bin are in Velodyne coord.
16 |
17 | y_image2 = P^2_rect * x_rect
18 | y_image2 = P^2_rect * R0_rect * Tr_velo_to_cam * x_velo
19 | x_ref = Tr_velo_to_cam * x_velo
20 | x_rect = R0_rect * x_ref
21 |
22 | P^2_rect = [f^2_u, 0, c^2_u, -f^2_u b^2_x;
23 | 0, f^2_v, c^2_v, -f^2_v b^2_y;
24 | 0, 0, 1, 0]
25 | = K * [1|t]
26 |
27 | image2 coord:
28 | ----> x-axis (u)
29 | |
30 | |
31 | v y-axis (v)
32 |
33 | velodyne coord:
34 | front x, left y, up z
35 |
36 | rect/ref camera coord:
37 | right x, down y, front z
38 |
39 | Ref (KITTI paper): http://www.cvlibs.net/publications/Geiger2013IJRR.pdf
40 |
41 | TODO(rqi): do matrix multiplication only once for each projection.
42 | '''
43 |
44 | def __init__(self, calib_filepath):
45 |
46 | calibs = self.read_calib_file(calib_filepath)
47 | # Projection matrix from rect camera coord to image2 coord
48 | self.P = calibs['P2']
49 | self.P = np.reshape(self.P, [3, 4])
50 | # Rigid transform from Velodyne coord to reference camera coord
51 | self.V2C = calibs['Tr_velo_to_cam']
52 | self.V2C = np.reshape(self.V2C, [3, 4])
53 | self.C2V = inverse_rigid_trans(self.V2C)
54 | # Rotation from reference camera coord to rect camera coord
55 | self.R0 = calibs['R0_rect']
56 | self.R0 = np.reshape(self.R0, [3, 3])
57 |
58 | # Camera intrinsics and extrinsics
59 | self.c_u = self.P[0, 2]
60 | self.c_v = self.P[1, 2]
61 | self.f_u = self.P[0, 0]
62 | self.f_v = self.P[1, 1]
63 | self.b_x = self.P[0, 3] / (-self.f_u) # relative
64 | self.b_y = self.P[1, 3] / (-self.f_v)
65 |
66 | def read_calib_file(self, filepath):
67 | ''' Read in a calibration file and parse into a dictionary.
68 | Ref: https://github.com/utiasSTARS/pykitti/blob/master/pykitti/utils.py
69 | '''
70 | data = {}
71 | with open(filepath, 'r') as f:
72 | for line in f.readlines():
73 | line = line.rstrip()
74 | if len(line) == 0: continue
75 | key, value = line.split(':', 1)
76 | # The only non-float values in these files are dates, which
77 | # we don't care about anyway
78 | try:
79 | data[key] = np.array([float(x) for x in value.split()])
80 | except ValueError:
81 | pass
82 |
83 | return data
84 |
85 | def cart2hom(self, pts_3d):
86 | ''' Input: nx3 points in Cartesian
87 | Output: nx4 points in homogeneous coordinates by appending a 1
88 | '''
89 | n = pts_3d.shape[0]
90 | pts_3d_hom = np.hstack((pts_3d, np.ones((n, 1))))
91 | return pts_3d_hom
92 |
93 | # ===========================
94 | # ------- 3d to 3d ----------
95 | # ===========================
96 | def project_velo_to_ref(self, pts_3d_velo):
97 | pts_3d_velo = self.cart2hom(pts_3d_velo) # nx4
98 | return np.dot(pts_3d_velo, np.transpose(self.V2C))
99 |
100 | def project_ref_to_velo(self, pts_3d_ref):
101 | pts_3d_ref = self.cart2hom(pts_3d_ref) # nx4
102 | return np.dot(pts_3d_ref, np.transpose(self.C2V))
103 |
104 | def project_rect_to_ref(self, pts_3d_rect):
105 | ''' Input and Output are nx3 points '''
106 | return np.transpose(np.dot(np.linalg.inv(self.R0), np.transpose(pts_3d_rect)))
107 |
108 | def project_ref_to_rect(self, pts_3d_ref):
109 | ''' Input and Output are nx3 points '''
110 | return np.transpose(np.dot(self.R0, np.transpose(pts_3d_ref)))
111 |
112 | def project_rect_to_velo(self, pts_3d_rect):
113 | ''' Input: nx3 points in rect camera coord.
114 | Output: nx3 points in velodyne coord.
115 | '''
116 | pts_3d_ref = self.project_rect_to_ref(pts_3d_rect)
117 | return self.project_ref_to_velo(pts_3d_ref)
118 |
119 | def project_velo_to_rect(self, pts_3d_velo):
120 | pts_3d_ref = self.project_velo_to_ref(pts_3d_velo)
121 | return self.project_ref_to_rect(pts_3d_ref)
122 |
123 | # ===========================
124 | # ------- 3d to 2d ----------
125 | # ===========================
126 | def project_rect_to_image(self, pts_3d_rect):
127 | ''' Input: nx3 points in rect camera coord.
128 | Output: nx2 points in image2 coord.
129 | '''
130 | pts_3d_rect = self.cart2hom(pts_3d_rect)
131 | pts_2d = np.dot(pts_3d_rect, np.transpose(self.P)) # nx3
132 | pts_2d[:, 0] /= pts_2d[:, 2]
133 | pts_2d[:, 1] /= pts_2d[:, 2]
134 | return pts_2d[:, 0:2]
135 |
136 | def project_velo_to_image(self, pts_3d_velo):
137 | ''' Input: nx3 points in velodyne coord.
138 | Output: nx2 points in image2 coord.
139 | '''
140 | pts_3d_rect = self.project_velo_to_rect(pts_3d_velo)
141 | return self.project_rect_to_image(pts_3d_rect)
142 |
143 | # ===========================
144 | # ------- 2d to 3d ----------
145 | # ===========================
146 | def project_image_to_rect(self, uv_depth):
147 | ''' Input: nx3 first two channels are uv, 3rd channel
148 | is depth in rect camera coord.
149 | Output: nx3 points in rect camera coord.
150 | '''
151 | n = uv_depth.shape[0]
152 | x = ((uv_depth[:, 0] - self.c_u) * uv_depth[:, 2]) / self.f_u + self.b_x
153 | y = ((uv_depth[:, 1] - self.c_v) * uv_depth[:, 2]) / self.f_v + self.b_y
154 | pts_3d_rect = np.zeros((n, 3))
155 | pts_3d_rect[:, 0] = x
156 | pts_3d_rect[:, 1] = y
157 | pts_3d_rect[:, 2] = uv_depth[:, 2]
158 | return pts_3d_rect
159 |
160 | def project_image_to_velo(self, uv_depth):
161 | pts_3d_rect = self.project_image_to_rect(uv_depth)
162 | return self.project_rect_to_velo(pts_3d_rect)
163 |
164 |
165 | def inverse_rigid_trans(Tr):
166 | ''' Inverse a rigid body transform matrix (3x4 as [R|t])
167 | [R'|-R't; 0|1]
168 | '''
169 | inv_Tr = np.zeros_like(Tr) # 3x4
170 | inv_Tr[0:3, 0:3] = np.transpose(Tr[0:3, 0:3])
171 | inv_Tr[0:3, 3] = np.dot(-np.transpose(Tr[0:3, 0:3]), Tr[0:3, 3])
172 | return inv_Tr
173 |
--------------------------------------------------------------------------------
/psmnet/README.md:
--------------------------------------------------------------------------------
1 | # Pyramid Stereo Matching Network
2 |
3 | This repository contains the code (in PyTorch) for "[Pyramid Stereo Matching Network](https://arxiv.org/abs/1803.08669)" paper (CVPR 2018) by [Jia-Ren Chang](https://jiarenchang.github.io/) and [Yong-Sheng Chen](https://people.cs.nctu.edu.tw/~yschen/).
4 |
5 | ### Citation
6 | ```
7 | @inproceedings{chang2018pyramid,
8 | title={Pyramid Stereo Matching Network},
9 | author={Chang, Jia-Ren and Chen, Yong-Sheng},
10 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
11 | pages={5410--5418},
12 | year={2018}
13 | }
14 | ```
15 |
16 | ## Contents
17 |
18 | 1. [Introduction](#introduction)
19 | 2. [Usage](#usage)
20 | 3. [Results](#results)
21 | 4. [Contacts](#contacts)
22 |
23 | ## Introduction
24 |
25 | Recent work has shown that depth estimation from a stereo pair of images can be formulated as a supervised learning task to be resolved with convolutional neural networks (CNNs). However, current architectures rely on patch-based Siamese networks, lacking the means to exploit context information for finding correspondence in ill-posed regions. To tackle this problem, we propose PSMNet, a pyramid stereo matching network consisting of two main modules: spatial pyramid pooling and 3D CNN. The spatial pyramid pooling module takes advantage of the capacity of global context information by aggregating context in different scales and locations to form a cost volume. The 3D CNN learns to regularize the cost volume using multiple stacked hourglass networks in conjunction with intermediate supervision.
26 |
27 |
28 |
29 | ## Usage
30 |
31 | ### Dependencies
32 |
33 | - [Python2.7](https://www.python.org/downloads/)
34 | - [PyTorch(0.4.0+)](http://pytorch.org)
35 | - torchvision 0.2.0 (higher version may cause issues)
36 | - [KITTI Stereo](http://www.cvlibs.net/datasets/kitti/eval_stereo.php)
37 | - [Scene Flow](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html)
38 |
39 | ```
40 | Usage of Scene Flow dataset
41 | Download the RGB cleanpass images and their disparity maps for the three subsets: FlyingThings3D, Driving, and Monkaa.
42 | Put them in the same folder.
43 | Then rename the folders as: "driving_frames_cleanpass", "driving_disparity", "monkaa_frames_cleanpass", "monkaa_disparity", "frames_cleanpass", "frames_disparity".
44 | ```
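
With that naming, the layout that `dataloader/listflowfile.py` scans under `--datapath` looks roughly like this (`scene_flow/` is just a placeholder for your data folder):

```
scene_flow/
    frames_cleanpass/             # FlyingThings3D RGB
    frames_disparity/             # FlyingThings3D disparity
    driving_frames_cleanpass/
    driving_disparity/
    monkaa_frames_cleanpass/
    monkaa_disparity/
```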
45 |
46 | ### Train
47 | As an example, use the following command to train a PSMNet on Scene Flow
48 |
49 | ```
50 | python main.py --maxdisp 192 \
51 | --model stackhourglass \
52 | --datapath (your scene flow data folder)\
53 | --epochs 10 \
54 | --loadmodel (optional)\
55 | --savemodel (path for saving model)
56 | ```
57 |
58 | As another example, use the following command to finetune a PSMNet on KITTI 2015
59 |
60 | ```
61 | python finetune.py --maxdisp 192 \
62 | --model stackhourglass \
63 | --datatype 2015 \
64 | --datapath (KITTI 2015 training data folder) \
65 | --epochs 300 \
66 | --loadmodel (pretrained PSMNet) \
67 | --savemodel (path for saving model)
68 | ```
69 | You can also see those examples in run.sh.
70 |
71 | ### Evaluation
72 | Use the following command to evaluate the trained PSMNet on KITTI 2015 test data
73 |
74 | ```
75 | python submission.py --maxdisp 192 \
76 | --model stackhourglass \
77 | --KITTI 2015 \
78 | --datapath (KITTI 2015 test data folder) \
79 | --loadmodel (finetuned PSMNet) \
80 | ```
81 |
82 | ### Pretrained Model
83 | ※NOTE: The pretrained models are saved as .tar files; however, you don't need to untar them. Use torch.load() to load them.
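
For example, here is a minimal loading sketch; it assumes the checkpoint stores its weights under a `state_dict` key and was saved from an `nn.DataParallel`-wrapped model, as in the original PSMNet training scripts:

```python
import torch
import torch.nn as nn
from models import stackhourglass  # run from the psmnet folder; import path assumed

model = nn.DataParallel(stackhourglass(maxdisp=192))
checkpoint = torch.load('pretrained_sceneflow.tar', map_location='cpu')
model.load_state_dict(checkpoint['state_dict'])
```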
84 |
85 | Update: 2018/9/6 We released the pre-trained KITTI 2012 model.
86 |
87 | | KITTI 2015 | Scene Flow | KITTI 2012|
88 | |---|---|---|
89 | |[Google Drive](https://drive.google.com/file/d/1pHWjmhKMG4ffCrpcsp_MTXMJXhgl3kF9/view?usp=sharing)|[Google Drive](https://drive.google.com/file/d/1xoqkQ2NXik1TML_FMUTNZJFAHrhLdKZG/view?usp=sharing)|[Google Drive](https://drive.google.com/file/d/1p4eJ2xDzvQxaqB20A_MmSP9-KORBX1pZ/view)|
90 |
91 |
92 | ## Results
93 |
94 | ### Evaluation of PSMNet with different settings
95 |
96 |
97 | ※Note that the reported 3-px validation errors were calculated using KITTI's official matlab code, not our code.
98 |
99 | ### Results on KITTI 2015 leaderboard
100 | [Leaderboard Link](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=stereo)
101 |
102 | | Method | D1-all (All) | D1-all (Noc)| Runtime (s) |
103 | |---|---|---|---|
104 | | PSMNet | 2.32 % | 2.14 % | 0.41 |
105 | | [iResNet-i2](https://arxiv.org/abs/1712.01039) | 2.44 % | 2.19 % | 0.12 |
106 | | [GC-Net](https://arxiv.org/abs/1703.04309) | 2.87 % | 2.61 % | 0.90 |
107 | | [MC-CNN](https://github.com/jzbontar/mc-cnn) | 3.89 % | 3.33 % | 67 |
108 |
109 | ### Qualitative results
110 | #### Left image
111 |
112 |
113 | #### Predicted disparity
114 |
115 |
116 | #### Error
117 |
118 |
119 | ### Visualization of Receptive Field
120 | We visualize the receptive fields of different settings of PSMNet, full setting and baseline.
121 |
122 | Full setting: dilated conv, SPP, stacked hourglass
123 |
124 | Baseline: no dilated conv, no SPP, no stacked hourglass
125 |
126 | The receptive fields were calculated for the pixel at image center, indicated by the red cross.
127 |
128 |
129 |
130 |
131 |
132 | ## Contacts
133 | followwar@gmail.com
134 |
135 | Any discussions or concerns are welcomed!
136 |
--------------------------------------------------------------------------------
/psmnet/dataloader/KITTILoader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.utils.data as data
4 | import torch
5 | import torchvision.transforms as transforms
6 | import random
7 | from PIL import Image, ImageOps
8 | import numpy as np
9 | import preprocess
10 |
11 | IMG_EXTENSIONS = [
12 | '.jpg', '.JPG', '.jpeg', '.JPEG',
13 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
14 | ]
15 |
16 | def is_image_file(filename):
17 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
18 |
19 | def default_loader(path):
20 | return Image.open(path).convert('RGB')
21 |
22 | def disparity_loader(path):
23 | return Image.open(path)
24 |
25 |
26 | class myImageFloder(data.Dataset):
27 | def __init__(self, left, right, left_disparity, training, loader=default_loader, dploader= disparity_loader):
28 |
29 | self.left = left
30 | self.right = right
31 | self.disp_L = left_disparity
32 | self.loader = loader
33 | self.dploader = dploader
34 | self.training = training
35 |
36 | def __getitem__(self, index):
37 | left = self.left[index]
38 | right = self.right[index]
39 | disp_L= self.disp_L[index]
40 |
41 | left_img = self.loader(left)
42 | right_img = self.loader(right)
43 | dataL = self.dploader(disp_L)
44 |
45 |
46 | if self.training:
47 | w, h = left_img.size
48 | th, tw = 256, 512
49 |
50 | x1 = random.randint(0, w - tw)
51 | y1 = random.randint(0, h - th)
52 |
53 | left_img = left_img.crop((x1, y1, x1 + tw, y1 + th))
54 | right_img = right_img.crop((x1, y1, x1 + tw, y1 + th))
55 |
56 | dataL = np.ascontiguousarray(dataL,dtype=np.float32)/256
57 | dataL = dataL[y1:y1 + th, x1:x1 + tw]
58 |
59 | processed = preprocess.get_transform(augment=False)
60 | left_img = processed(left_img)
61 | right_img = processed(right_img)
62 |
63 | return left_img, right_img, dataL
64 | else:
65 | w, h = left_img.size
66 |
67 | left_img = left_img.crop((w-1232, h-368, w, h))
68 | right_img = right_img.crop((w-1232, h-368, w, h))
69 | w1, h1 = left_img.size
70 |
71 | dataL = dataL.crop((w-1232, h-368, w, h))
72 | dataL = np.ascontiguousarray(dataL,dtype=np.float32)/256
73 |
74 | processed = preprocess.get_transform(augment=False)
75 | left_img = processed(left_img)
76 | right_img = processed(right_img)
77 |
78 | return left_img, right_img, dataL
79 |
80 | def __len__(self):
81 | return len(self.left)
82 |
--------------------------------------------------------------------------------
/psmnet/dataloader/KITTILoader3D.py:
--------------------------------------------------------------------------------
1 | IMG_EXTENSIONS = [
2 | '.jpg', '.JPG', '.jpeg', '.JPEG',
3 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
4 | ]
5 |
6 |
7 | def is_image_file(filename):
8 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
9 |
10 |
11 | def dataloader(filepath, train_file):
12 | left_fold = 'image_2/'
13 | right_fold = 'image_3/'
14 | disp_L = 'disparity/'
15 |
16 | with open(train_file, 'r') as f:
17 | train_idx = [x.strip() for x in f.readlines()]
18 |
19 | left_train = [filepath + '/' + left_fold + img + '.png' for img in train_idx]
20 | right_train = [filepath + '/' + right_fold + img + '.png' for img in train_idx]
21 | disp_train_L = [filepath + '/' + disp_L + img + '.npy' for img in train_idx]
22 |
23 | return left_train, right_train, disp_train_L
24 |
--------------------------------------------------------------------------------
/psmnet/dataloader/KITTILoader_dataset3d.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import numpy as np
4 | import preprocess
5 | import torch
6 | import torch.utils.data as data
7 | from PIL import Image
8 |
9 | IMG_EXTENSIONS = [
10 | '.jpg', '.JPG', '.jpeg', '.JPEG',
11 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
12 | ]
13 |
14 |
15 | def is_image_file(filename):
16 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
17 |
18 |
19 | def default_loader(path):
20 | return Image.open(path).convert('RGB')
21 |
22 |
23 | def disparity_loader(path):
24 | return np.load(path).astype(np.float32)
25 |
26 |
27 | class myImageFloder(data.Dataset):
28 | def __init__(self, left, right, left_disparity, training, loader=default_loader, dploader=disparity_loader):
29 |
30 | self.left = left
31 | self.right = right
32 | self.disp_L = left_disparity
33 | self.loader = loader
34 | self.dploader = dploader
35 | self.training = training
36 |
37 | def __getitem__(self, index):
38 | left = self.left[index]
39 | right = self.right[index]
40 | disp_L = self.disp_L[index]
41 |
42 | left_img = self.loader(left)
43 | right_img = self.loader(right)
44 | dataL = self.dploader(disp_L)
45 |
46 | if self.training:
47 | w, h = left_img.size
48 | th, tw = 256, 512
49 |
50 | x1 = random.randint(0, w - tw)
51 | y1 = random.randint(0, h - th)
52 |
53 | left_img = left_img.crop((x1, y1, x1 + tw, y1 + th))
54 | right_img = right_img.crop((x1, y1, x1 + tw, y1 + th))
55 |
56 | dataL = dataL[y1:y1 + th, x1:x1 + tw]
57 |
58 | processed = preprocess.get_transform(augment=False)
59 | left_img = processed(left_img)
60 | right_img = processed(right_img)
61 |
62 | else:
63 | w, h = left_img.size
64 |
65 | # left_img = left_img.crop((w - 1232, h - 368, w, h))
66 | # right_img = right_img.crop((w - 1232, h - 368, w, h))
67 | left_img = left_img.crop((w - 1200, h - 352, w, h))
68 | right_img = right_img.crop((w - 1200, h - 352, w, h))
69 | w1, h1 = left_img.size
70 |
71 | # dataL1 = dataL[h - 368:h, w - 1232:w]
72 | dataL = dataL[h - 352:h, w - 1200:w]
73 |
74 | processed = preprocess.get_transform(augment=False)
75 | left_img = processed(left_img)
76 | right_img = processed(right_img)
77 |
78 | dataL = torch.from_numpy(dataL).float()
79 | return left_img, right_img, dataL
80 |
81 | def __len__(self):
82 | return len(self.left)
83 |
--------------------------------------------------------------------------------
/psmnet/dataloader/KITTI_submission_loader.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 | from PIL import Image
4 | import os
5 | import os.path
6 | import numpy as np
7 |
8 | IMG_EXTENSIONS = [
9 | '.jpg', '.JPG', '.jpeg', '.JPEG',
10 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
11 | ]
12 |
13 |
14 | def is_image_file(filename):
15 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
16 |
17 | def dataloader(filepath):
18 |
19 | left_fold = 'image_2/'
20 | right_fold = 'image_3/'
21 | # left_fold = 'image_2/data/'
22 | # right_fold = 'image_3/data/'
23 |
24 |
25 | # image = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1]
26 | image = [img for img in os.listdir(filepath+left_fold)]
27 | image = sorted(image)
28 |
29 |
30 | left_test = [filepath+left_fold+img for img in image]
31 | right_test = [filepath+right_fold+img for img in image]
32 |
33 | return left_test, right_test
34 |
--------------------------------------------------------------------------------
/psmnet/dataloader/KITTI_submission_loader2012.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 | from PIL import Image
4 | import os
5 | import os.path
6 | import numpy as np
7 |
8 | IMG_EXTENSIONS = [
9 | '.jpg', '.JPG', '.jpeg', '.JPEG',
10 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
11 | ]
12 |
13 |
14 | def is_image_file(filename):
15 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
16 |
17 | def dataloader(filepath):
18 |
19 | left_fold = 'colored_0/'
20 | right_fold = 'colored_1/'
21 |
22 |
23 | image = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1]
24 |
25 |
26 | left_test = [filepath+left_fold+img for img in image]
27 | right_test = [filepath+right_fold+img for img in image]
28 |
29 | return left_test, right_test
30 |
--------------------------------------------------------------------------------
/psmnet/dataloader/KITTIloader2012.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 | from PIL import Image
4 | import os
5 | import os.path
6 | import numpy as np
7 |
8 | IMG_EXTENSIONS = [
9 | '.jpg', '.JPG', '.jpeg', '.JPEG',
10 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
11 | ]
12 |
13 |
14 | def is_image_file(filename):
15 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
16 |
17 | def dataloader(filepath):
18 |
19 | left_fold = 'colored_0/'
20 | right_fold = 'colored_1/'
21 | disp_noc = 'disp_occ/'
22 |
23 | image = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1]
24 |
25 | train = image[:]
26 | val = image[160:]
27 |
28 | left_train = [filepath+left_fold+img for img in train]
29 | right_train = [filepath+right_fold+img for img in train]
30 | disp_train = [filepath+disp_noc+img for img in train]
31 |
32 |
33 | left_val = [filepath+left_fold+img for img in val]
34 | right_val = [filepath+right_fold+img for img in val]
35 | disp_val = [filepath+disp_noc+img for img in val]
36 |
37 | return left_train, right_train, disp_train, left_val, right_val, disp_val
38 |
--------------------------------------------------------------------------------
/psmnet/dataloader/KITTIloader2015.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 | from PIL import Image
4 | import os
5 | import os.path
6 | import numpy as np
7 |
8 | IMG_EXTENSIONS = [
9 | '.jpg', '.JPG', '.jpeg', '.JPEG',
10 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
11 | ]
12 |
13 |
14 | def is_image_file(filename):
15 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
16 |
17 | def dataloader(filepath):
18 |
19 | left_fold = 'image_2/'
20 | right_fold = 'image_3/'
21 | disp_L = 'disp_occ_0/'
22 | disp_R = 'disp_occ_1/'
23 |
24 | image = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1]
25 |
26 | train = image[:160]
27 | val = image[160:]
28 |
29 | left_train = [filepath+left_fold+img for img in train]
30 | right_train = [filepath+right_fold+img for img in train]
31 | disp_train_L = [filepath+disp_L+img for img in train]
32 | #disp_train_R = [filepath+disp_R+img for img in train]
33 |
34 | left_val = [filepath+left_fold+img for img in val]
35 | right_val = [filepath+right_fold+img for img in val]
36 | disp_val_L = [filepath+disp_L+img for img in val]
37 | #disp_val_R = [filepath+disp_R+img for img in val]
38 |
39 | return left_train, right_train, disp_train_L, left_val, right_val, disp_val_L
40 |
--------------------------------------------------------------------------------
/psmnet/dataloader/SecenFlowLoader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.utils.data as data
4 | import torch
5 | import torchvision.transforms as transforms
6 | import random
7 | from PIL import Image, ImageOps
8 | import preprocess
9 | import listflowfile as lt
10 | import readpfm as rp
11 | import numpy as np
12 |
13 | IMG_EXTENSIONS = [
14 | '.jpg', '.JPG', '.jpeg', '.JPEG',
15 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
16 | ]
17 |
18 | def is_image_file(filename):
19 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
20 |
21 | def default_loader(path):
22 | return Image.open(path).convert('RGB')
23 |
24 | def disparity_loader(path):
25 | return rp.readPFM(path)
26 |
27 |
28 | class myImageFloder(data.Dataset):
29 | def __init__(self, left, right, left_disparity, training, loader=default_loader, dploader= disparity_loader):
30 |
31 | self.left = left
32 | self.right = right
33 | self.disp_L = left_disparity
34 | self.loader = loader
35 | self.dploader = dploader
36 | self.training = training
37 |
38 | def __getitem__(self, index):
39 | left = self.left[index]
40 | right = self.right[index]
41 | disp_L= self.disp_L[index]
42 |
43 |
44 | left_img = self.loader(left)
45 | right_img = self.loader(right)
46 | dataL, scaleL = self.dploader(disp_L)
47 | dataL = np.ascontiguousarray(dataL,dtype=np.float32)
48 |
49 |
50 |
51 | if self.training:
52 | w, h = left_img.size
53 | th, tw = 256, 512
54 |
55 | x1 = random.randint(0, w - tw)
56 | y1 = random.randint(0, h - th)
57 |
58 | left_img = left_img.crop((x1, y1, x1 + tw, y1 + th))
59 | right_img = right_img.crop((x1, y1, x1 + tw, y1 + th))
60 |
61 | dataL = dataL[y1:y1 + th, x1:x1 + tw]
62 |
63 | processed = preprocess.get_transform(augment=False)
64 | left_img = processed(left_img)
65 | right_img = processed(right_img)
66 |
67 | return left_img, right_img, dataL
68 | else:
69 | w, h = left_img.size
70 | left_img = left_img.crop((w-960, h-544, w, h))
71 | right_img = right_img.crop((w-960, h-544, w, h))
72 | processed = preprocess.get_transform(augment=False)
73 | left_img = processed(left_img)
74 | right_img = processed(right_img)
75 |
76 | return left_img, right_img, dataL
77 |
78 | def __len__(self):
79 | return len(self.left)
80 |
--------------------------------------------------------------------------------
/psmnet/dataloader/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mileyan/pseudo_lidar/032c7a0d73c3fdf84e934af3f57f8eb489a52906/psmnet/dataloader/__init__.py
--------------------------------------------------------------------------------
/psmnet/dataloader/listflowfile.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 | from PIL import Image
4 | import os
5 | import os.path
6 |
7 | IMG_EXTENSIONS = [
8 | '.jpg', '.JPG', '.jpeg', '.JPEG',
9 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
10 | ]
11 |
12 |
13 | def is_image_file(filename):
14 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
15 |
16 | def dataloader(filepath):
17 |
18 | classes = [d for d in os.listdir(filepath) if os.path.isdir(os.path.join(filepath, d))]
19 | image = [img for img in classes if img.find('frames_cleanpass') > -1]
20 | disp = [dsp for dsp in classes if dsp.find('disparity') > -1]
21 |
22 | monkaa_path = filepath + [x for x in image if 'monkaa' in x][0]
23 | monkaa_disp = filepath + [x for x in disp if 'monkaa' in x][0]
24 |
25 |
26 | monkaa_dir = os.listdir(monkaa_path)
27 |
28 | all_left_img=[]
29 | all_right_img=[]
30 | all_left_disp = []
31 | test_left_img=[]
32 | test_right_img=[]
33 | test_left_disp = []
34 |
35 |
36 | for dd in monkaa_dir:
37 | for im in os.listdir(monkaa_path+'/'+dd+'/left/'):
38 | if is_image_file(monkaa_path+'/'+dd+'/left/'+im):
39 | all_left_img.append(monkaa_path+'/'+dd+'/left/'+im)
40 | all_left_disp.append(monkaa_disp+'/'+dd+'/left/'+im.split(".")[0]+'.pfm')
41 |
42 | for im in os.listdir(monkaa_path+'/'+dd+'/right/'):
43 | if is_image_file(monkaa_path+'/'+dd+'/right/'+im):
44 | all_right_img.append(monkaa_path+'/'+dd+'/right/'+im)
45 |
46 | flying_path = filepath + [x for x in image if x == 'frames_cleanpass'][0]
47 | flying_disp = filepath + [x for x in disp if x == 'frames_disparity'][0]
48 | flying_dir = flying_path+'/TRAIN/'
49 | subdir = ['A','B','C']
50 |
51 | for ss in subdir:
52 | flying = os.listdir(flying_dir+ss)
53 |
54 | for ff in flying:
55 | imm_l = os.listdir(flying_dir+ss+'/'+ff+'/left/')
56 | for im in imm_l:
57 | if is_image_file(flying_dir+ss+'/'+ff+'/left/'+im):
58 | all_left_img.append(flying_dir+ss+'/'+ff+'/left/'+im)
59 |
60 | all_left_disp.append(flying_disp+'/TRAIN/'+ss+'/'+ff+'/left/'+im.split(".")[0]+'.pfm')
61 |
62 | if is_image_file(flying_dir+ss+'/'+ff+'/right/'+im):
63 | all_right_img.append(flying_dir+ss+'/'+ff+'/right/'+im)
64 |
65 | flying_dir = flying_path+'/TEST/'
66 |
67 | subdir = ['A','B','C']
68 |
69 | for ss in subdir:
70 | flying = os.listdir(flying_dir+ss)
71 |
72 | for ff in flying:
73 | imm_l = os.listdir(flying_dir+ss+'/'+ff+'/left/')
74 | for im in imm_l:
75 | if is_image_file(flying_dir+ss+'/'+ff+'/left/'+im):
76 | test_left_img.append(flying_dir+ss+'/'+ff+'/left/'+im)
77 |
78 | test_left_disp.append(flying_disp+'/TEST/'+ss+'/'+ff+'/left/'+im.split(".")[0]+'.pfm')
79 |
80 | if is_image_file(flying_dir+ss+'/'+ff+'/right/'+im):
81 | test_right_img.append(flying_dir+ss+'/'+ff+'/right/'+im)
82 |
83 |
84 |
85 | driving_dir = filepath + [x for x in image if 'driving' in x][0] + '/'
86 | driving_disp = filepath + [x for x in disp if 'driving' in x][0]
87 |
88 | subdir1 = ['35mm_focallength','15mm_focallength']
89 | subdir2 = ['scene_backwards','scene_forwards']
90 | subdir3 = ['fast','slow']
91 |
92 | for i in subdir1:
93 | for j in subdir2:
94 | for k in subdir3:
95 | imm_l = os.listdir(driving_dir+i+'/'+j+'/'+k+'/left/')
96 | for im in imm_l:
97 | if is_image_file(driving_dir+i+'/'+j+'/'+k+'/left/'+im):
98 | all_left_img.append(driving_dir+i+'/'+j+'/'+k+'/left/'+im)
99 | all_left_disp.append(driving_disp+'/'+i+'/'+j+'/'+k+'/left/'+im.split(".")[0]+'.pfm')
100 |
101 | if is_image_file(driving_dir+i+'/'+j+'/'+k+'/right/'+im):
102 | all_right_img.append(driving_dir+i+'/'+j+'/'+k+'/right/'+im)
103 |
104 |
105 | return all_left_img, all_right_img, all_left_disp, test_left_img, test_right_img, test_left_disp
106 |
107 |
108 |
--------------------------------------------------------------------------------
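
The loader above walks a SceneFlow root containing the `frames_cleanpass` and disparity folders (Monkaa, FlyingThings3D, Driving) and returns six path lists. A minimal usage sketch (not part of the repository), assuming `psmnet/` is on `PYTHONPATH`; note the root must end with a trailing slash because paths are built by plain string concatenation:

    from dataloader import listflowfile as lt

    # hypothetical SceneFlow root; must end with '/'
    (all_left, all_right, all_disp,
     test_left, test_right, test_disp) = lt.dataloader('/path/to/sceneflow/')
    print(len(all_left), 'training pairs,', len(test_left), 'test pairs')
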
/psmnet/dataloader/preprocess.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision.transforms as transforms
3 | import random
4 |
5 | __imagenet_stats = {'mean': [0.485, 0.456, 0.406],
6 | 'std': [0.229, 0.224, 0.225]}
7 |
8 | #__imagenet_stats = {'mean': [0.5, 0.5, 0.5],
9 | # 'std': [0.5, 0.5, 0.5]}
10 |
11 | __imagenet_pca = {
12 | 'eigval': torch.Tensor([0.2175, 0.0188, 0.0045]),
13 | 'eigvec': torch.Tensor([
14 | [-0.5675, 0.7192, 0.4009],
15 | [-0.5808, -0.0045, -0.8140],
16 | [-0.5836, -0.6948, 0.4203],
17 | ])
18 | }
19 |
20 |
21 | def scale_crop(input_size, scale_size=None, normalize=__imagenet_stats):
22 | t_list = [
23 | transforms.ToTensor(),
24 | transforms.Normalize(**normalize),
25 | ]
26 | #if scale_size != input_size:
27 | #t_list = [transforms.Scale((960,540))] + t_list
28 |
29 | return transforms.Compose(t_list)
30 |
31 |
32 | def scale_random_crop(input_size, scale_size=None, normalize=__imagenet_stats):
33 | t_list = [
34 | transforms.RandomCrop(input_size),
35 | transforms.ToTensor(),
36 | transforms.Normalize(**normalize),
37 | ]
38 | if scale_size != input_size:
39 | t_list = [transforms.Scale(scale_size)] + t_list
40 |
41 | return transforms.Compose(t_list)
42 |
43 |
44 | def pad_random_crop(input_size, scale_size=None, normalize=__imagenet_stats):
45 | padding = int((scale_size - input_size) / 2)
46 | return transforms.Compose([
47 | transforms.RandomCrop(input_size, padding=padding),
48 | transforms.RandomHorizontalFlip(),
49 | transforms.ToTensor(),
50 | transforms.Normalize(**normalize),
51 | ])
52 |
53 |
54 | def inception_preproccess(input_size, normalize=__imagenet_stats):
55 | return transforms.Compose([
56 | transforms.RandomSizedCrop(input_size),
57 | transforms.RandomHorizontalFlip(),
58 | transforms.ToTensor(),
59 | transforms.Normalize(**normalize)
60 | ])
61 | def inception_color_preproccess(input_size, normalize=__imagenet_stats):
62 | return transforms.Compose([
63 | #transforms.RandomSizedCrop(input_size),
64 | #transforms.RandomHorizontalFlip(),
65 | transforms.ToTensor(),
66 | ColorJitter(
67 | brightness=0.4,
68 | contrast=0.4,
69 | saturation=0.4,
70 | ),
71 | Lighting(0.1, __imagenet_pca['eigval'], __imagenet_pca['eigvec']),
72 | transforms.Normalize(**normalize)
73 | ])
74 |
75 |
76 | def get_transform(name='imagenet', input_size=None,
77 | scale_size=None, normalize=None, augment=True):
78 | normalize = __imagenet_stats
79 | input_size = 256
80 | if augment:
81 | return inception_color_preproccess(input_size, normalize=normalize)
82 | else:
83 | return scale_crop(input_size=input_size,
84 | scale_size=scale_size, normalize=normalize)
85 |
86 |
87 |
88 |
89 | class Lighting(object):
90 | """Lighting noise(AlexNet - style PCA - based noise)"""
91 |
92 | def __init__(self, alphastd, eigval, eigvec):
93 | self.alphastd = alphastd
94 | self.eigval = eigval
95 | self.eigvec = eigvec
96 |
97 | def __call__(self, img):
98 | if self.alphastd == 0:
99 | return img
100 |
101 | alpha = img.new().resize_(3).normal_(0, self.alphastd)
102 | rgb = self.eigvec.type_as(img).clone()\
103 | .mul(alpha.view(1, 3).expand(3, 3))\
104 | .mul(self.eigval.view(1, 3).expand(3, 3))\
105 | .sum(1).squeeze()
106 |
107 | return img.add(rgb.view(3, 1, 1).expand_as(img))
108 |
109 |
110 | class Grayscale(object):
111 |
112 | def __call__(self, img):
113 | gs = img.clone()
114 | gs[0].mul_(0.299).add_(0.587, gs[1]).add_(0.114, gs[2])
115 | gs[1].copy_(gs[0])
116 | gs[2].copy_(gs[0])
117 | return gs
118 |
119 |
120 | class Saturation(object):
121 |
122 | def __init__(self, var):
123 | self.var = var
124 |
125 | def __call__(self, img):
126 | gs = Grayscale()(img)
127 | alpha = random.uniform(0, self.var)
128 | return img.lerp(gs, alpha)
129 |
130 |
131 | class Brightness(object):
132 |
133 | def __init__(self, var):
134 | self.var = var
135 |
136 | def __call__(self, img):
137 | gs = img.new().resize_as_(img).zero_()
138 | alpha = random.uniform(0, self.var)
139 | return img.lerp(gs, alpha)
140 |
141 |
142 | class Contrast(object):
143 |
144 | def __init__(self, var):
145 | self.var = var
146 |
147 | def __call__(self, img):
148 | gs = Grayscale()(img)
149 | gs.fill_(gs.mean())
150 | alpha = random.uniform(0, self.var)
151 | return img.lerp(gs, alpha)
152 |
153 |
154 | class RandomOrder(object):
155 | """ Composes several transforms together in random order.
156 | """
157 |
158 | def __init__(self, transforms):
159 | self.transforms = transforms
160 |
161 | def __call__(self, img):
162 | if self.transforms is None:
163 | return img
164 | order = torch.randperm(len(self.transforms))
165 | for i in order:
166 | img = self.transforms[i](img)
167 | return img
168 |
169 |
170 | class ColorJitter(RandomOrder):
171 |
172 | def __init__(self, brightness=0.4, contrast=0.4, saturation=0.4):
173 | self.transforms = []
174 | if brightness != 0:
175 | self.transforms.append(Brightness(brightness))
176 | if contrast != 0:
177 | self.transforms.append(Contrast(contrast))
178 | if saturation != 0:
179 | self.transforms.append(Saturation(saturation))
180 |
--------------------------------------------------------------------------------
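
The only transform used for inference elsewhere in this repository is `get_transform(augment=False)`, which reduces to `ToTensor` plus ImageNet normalization. A minimal sketch with a hypothetical image path, assuming `psmnet/` is on `PYTHONPATH`:

    from PIL import Image
    from dataloader import preprocess

    img = Image.open('left.png').convert('RGB')        # hypothetical input frame
    transform = preprocess.get_transform(augment=False)
    tensor = transform(img)                             # 3 x H x W, ImageNet-normalized
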
/psmnet/dataloader/readpfm.py:
--------------------------------------------------------------------------------
1 | import re
2 | import numpy as np
3 | import sys
4 |
5 |
6 | def readPFM(file):
7 | file = open(file, 'rb')
8 |
9 | color = None
10 | width = None
11 | height = None
12 | scale = None
13 | endian = None
14 |
15 | header = file.readline().rstrip()
16 | if header == 'PF':
17 | color = True
18 | elif header == 'Pf':
19 | color = False
20 | else:
21 | raise Exception('Not a PFM file.')
22 |
23 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline())
24 | if dim_match:
25 | width, height = map(int, dim_match.groups())
26 | else:
27 | raise Exception('Malformed PFM header.')
28 |
29 | scale = float(file.readline().rstrip())
30 | if scale < 0: # little-endian
31 | endian = '<'
32 | scale = -scale
33 | else:
34 | endian = '>' # big-endian
35 |
36 | data = np.fromfile(file, endian + 'f')
37 | shape = (height, width, 3) if color else (height, width)
38 |
39 | data = np.reshape(data, shape)
40 | data = np.flipud(data)
41 | return data, scale
42 |
43 |
--------------------------------------------------------------------------------
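
readPFM parses the PFM header with plain string comparisons, so it assumes Python 2 style text reads; under Python 3 the header lines come back as bytes and would need decoding first. A minimal usage sketch with a hypothetical SceneFlow disparity file, assuming `psmnet/` is on `PYTHONPATH`:

    from dataloader import readpfm as rp

    disp, scale = rp.readPFM('0006.pfm')   # hypothetical .pfm disparity map
    print(disp.shape, disp.dtype, scale)   # (H, W) float array plus the PFM scale factor
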
/psmnet/finetune_3d.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import argparse
4 | import os
5 | import time
6 |
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | import torch.nn.parallel
12 | import torch.optim as optim
13 | import torch.utils.data
14 | from torch.autograd import Variable
15 |
16 | import logger
17 | from dataloader import KITTILoader3D as ls
18 | from dataloader import KITTILoader_dataset3d as DA
19 | from models import *
20 |
21 | parser = argparse.ArgumentParser(description='PSMNet')
22 | parser.add_argument('--maxdisp', type=int, default=192,
23 | help='maximum disparity')
24 | parser.add_argument('--model', default='stackhourglass',
25 | help='select model')
26 | parser.add_argument('--datatype', default='2015',
27 | help='KITTI dataset version')
28 | parser.add_argument('--datapath', default='/media/jiaren/ImageNet/data_scene_flow_2015/training/',
29 | help='datapath')
30 | parser.add_argument('--epochs', type=int, default=300,
31 | help='number of epochs to train')
32 | parser.add_argument('--loadmodel', default='./trained/submission_model.tar',
33 | help='load model')
34 | parser.add_argument('--savemodel', default='./',
35 | help='save model')
36 | parser.add_argument('--no-cuda', action='store_true', default=False,
37 | help='disables CUDA training')
38 | parser.add_argument('--seed', type=int, default=1, metavar='S',
39 | help='random seed (default: 1)')
40 | parser.add_argument('--lr_scale', type=int, default=200, metavar='S',
41 | help='epoch after which the learning rate is decreased (default: 200)')
42 | parser.add_argument('--split_file', default='Kitti/object/train.txt',
43 | help='path to the training split file')
44 | parser.add_argument('--btrain', type=int, default=4)
45 | parser.add_argument('--start_epoch', type=int, default=1)
46 |
47 | args = parser.parse_args()
48 | args.cuda = not args.no_cuda and torch.cuda.is_available()
49 | torch.manual_seed(args.seed)
50 | if args.cuda:
51 | torch.cuda.manual_seed(args.seed)
52 |
53 | if not os.path.isdir(args.savemodel):
54 | os.makedirs(args.savemodel)
55 | print(os.path.join(args.savemodel, 'training.log'))
56 | log = logger.setup_logger(os.path.join(args.savemodel, 'training.log'))
57 |
58 | all_left_img, all_right_img, all_left_disp, = ls.dataloader(args.datapath,
59 | args.split_file)
60 |
61 | TrainImgLoader = torch.utils.data.DataLoader(
62 | DA.myImageFloder(all_left_img, all_right_img, all_left_disp, True),
63 | batch_size=args.btrain, shuffle=True, num_workers=14, drop_last=False)
64 |
65 | if args.model == 'stackhourglass':
66 | model = stackhourglass(args.maxdisp)
67 | elif args.model == 'basic':
68 | model = basic(args.maxdisp)
69 | else:
70 | print('no model')
71 |
72 | if args.cuda:
73 | model = nn.DataParallel(model)
74 | model.cuda()
75 |
76 | if args.loadmodel is not None:
77 | log.info('load model ' + args.loadmodel)
78 | state_dict = torch.load(args.loadmodel)
79 | model.load_state_dict(state_dict['state_dict'])
80 |
81 | print('Number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()])))
82 |
83 | optimizer = optim.Adam(model.parameters(), lr=0.1, betas=(0.9, 0.999))
84 |
85 |
86 | def train(imgL, imgR, disp_L):
87 | model.train()
88 | imgL = Variable(torch.FloatTensor(imgL))
89 | imgR = Variable(torch.FloatTensor(imgR))
90 | disp_L = Variable(torch.FloatTensor(disp_L))
91 |
92 | if args.cuda:
93 | imgL, imgR, disp_true = imgL.cuda(), imgR.cuda(), disp_L.cuda()
94 |
95 | # ---------
96 | mask = (disp_true > 0)
97 | mask.detach_()
98 | # ----
99 |
100 | optimizer.zero_grad()
101 |
102 | if args.model == 'stackhourglass':
103 | output1, output2, output3 = model(imgL, imgR)
104 | output1 = torch.squeeze(output1, 1)
105 | output2 = torch.squeeze(output2, 1)
106 | output3 = torch.squeeze(output3, 1)
107 | loss = 0.5 * F.smooth_l1_loss(output1[mask], disp_true[mask], size_average=True) + 0.7 * F.smooth_l1_loss(
108 | output2[mask], disp_true[mask], size_average=True) + F.smooth_l1_loss(output3[mask], disp_true[mask],
109 | size_average=True)
110 | elif args.model == 'basic':
111 | output = model(imgL, imgR)
112 | output = torch.squeeze(output, 1)
113 | loss = F.smooth_l1_loss(output[mask], disp_true[mask], size_average=True)
114 |
115 | loss.backward()
116 | optimizer.step()
117 |
118 | return loss.item()
119 |
120 |
121 | def test(imgL, imgR, disp_true):
122 | model.eval()
123 | imgL = Variable(torch.FloatTensor(imgL))
124 | imgR = Variable(torch.FloatTensor(imgR))
125 | if args.cuda:
126 | imgL, imgR = imgL.cuda(), imgR.cuda()
127 |
128 | with torch.no_grad():
129 | output3 = model(imgL, imgR)
130 |
131 | pred_disp = output3.data.cpu()
132 |
133 | # computing 3-px error#
134 | true_disp = disp_true
135 | index = np.argwhere(true_disp > 0)
136 | disp_true[index[0][:], index[1][:], index[2][:]] = np.abs(
137 | true_disp[index[0][:], index[1][:], index[2][:]] - pred_disp[index[0][:], index[1][:], index[2][:]])
138 | correct = (disp_true[index[0][:], index[1][:], index[2][:]] < 3) | (
139 | disp_true[index[0][:], index[1][:], index[2][:]] < true_disp[
140 | index[0][:], index[1][:], index[2][:]] * 0.05)
141 | torch.cuda.empty_cache()
142 |
143 | return 1 - (float(torch.sum(correct)) / float(len(index[0])))
144 |
145 |
146 | def adjust_learning_rate(optimizer, epoch):
147 | if epoch <= args.lr_scale:
148 | lr = 0.001
149 | else:
150 | lr = 0.0001
151 | for param_group in optimizer.param_groups:
152 | param_group['lr'] = lr
153 |
154 |
155 | def main():
156 | max_acc = 0
157 | max_epo = 0
158 | start_full_time = time.time()
159 |
160 | for epoch in range(args.start_epoch, args.epochs + 1):
161 | total_train_loss = 0
162 | adjust_learning_rate(optimizer, epoch)
163 |
164 | ## training ##
165 | for batch_idx, (imgL_crop, imgR_crop, disp_crop_L) in enumerate(TrainImgLoader):
166 | start_time = time.time()
167 |
168 | loss = train(imgL_crop, imgR_crop, disp_crop_L)
169 | print('Iter %d training loss = %.3f , time = %.2f' % (batch_idx, loss, time.time() - start_time))
170 | total_train_loss += loss
171 | print('epoch %d total training loss = %.3f' % (epoch, total_train_loss / len(TrainImgLoader)))
172 |
173 | # SAVE
174 | if not os.path.isdir(args.savemodel):
175 | os.makedirs(args.savemodel)
176 | savefilename = args.savemodel + '/finetune_' + str(epoch) + '.tar'
177 | torch.save({
178 | 'epoch': epoch,
179 | 'state_dict': model.state_dict(),
180 | 'train_loss': total_train_loss / len(TrainImgLoader),
181 | }, savefilename)
182 |
183 |
184 | class AverageMeter(object):
185 | """Computes and stores the average and current value"""
186 |
187 | def __init__(self):
188 | self.reset()
189 |
190 | def reset(self):
191 | self.val = 0
192 | self.avg = 0
193 | self.sum = 0
194 | self.count = 0
195 |
196 | def update(self, val, n=1):
197 | self.val = val
198 | self.sum += val * n
199 | self.count += n
200 | self.avg = self.sum / self.count
201 |
202 |
203 | if __name__ == '__main__':
204 | main()
205 |
--------------------------------------------------------------------------------
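
Each epoch, finetune_3d.py saves a checkpoint named finetune_<epoch>.tar under --savemodel with the keys written by torch.save() above. A minimal reload sketch (checkpoint path is hypothetical; the DataParallel wrapper matches how the state dict was saved):

    import torch
    from models import stackhourglass

    model = torch.nn.DataParallel(stackhourglass(192))            # keys were saved with the 'module.' prefix
    ckpt = torch.load('./finetune_300.tar',
                      map_location=lambda storage, loc: storage)  # load onto CPU regardless of saving device
    model.load_state_dict(ckpt['state_dict'])
    print('epoch:', ckpt['epoch'], 'train loss:', ckpt['train_loss'])
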
/psmnet/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 |
4 |
5 | def setup_logger(filepath):
6 | file_formatter = logging.Formatter(
7 | "[%(asctime)s %(filename)s:%(lineno)s] %(levelname)-8s %(message)s",
8 | datefmt='%Y-%m-%d %H:%M:%S',
9 | )
10 | logger = logging.getLogger('example')
11 | handler = logging.StreamHandler()
12 | handler.setFormatter(file_formatter)
13 | logger.addHandler(handler)
14 |
15 | file_handle_name = "file"
16 | if file_handle_name in [h.name for h in logger.handlers]:
17 | return logger
18 | if os.path.dirname(filepath) != '':
19 | if not os.path.isdir(os.path.dirname(filepath)):
20 | os.makedirs(os.path.dirname(filepath))
21 | file_handle = logging.FileHandler(filename=filepath, mode="a")
22 | file_handle.set_name(file_handle_name)
23 | file_handle.setFormatter(file_formatter)
24 | logger.addHandler(file_handle)
25 | logger.setLevel(logging.DEBUG)
26 | return logger
--------------------------------------------------------------------------------
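
A minimal usage sketch (not part of the repository): the returned logger writes to stderr and appends to the given file, creating the directory if needed. It assumes `psmnet/` is on `PYTHONPATH`; the log path is hypothetical:

    import logger

    log = logger.setup_logger('./logs/training.log')
    log.info('starting training')
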
/psmnet/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .basic import PSMNet as basic
2 | from .stackhourglass import PSMNet as stackhourglass
3 |
4 |
--------------------------------------------------------------------------------
/psmnet/models/basic.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import torch
3 | import torch.nn as nn
4 | import torch.utils.data
5 | from torch.autograd import Variable
6 | import torch.nn.functional as F
7 | import math
8 | from .submodule import *
9 |
10 | class PSMNet(nn.Module):
11 | def __init__(self, maxdisp):
12 | super(PSMNet, self).__init__()
13 | self.maxdisp = maxdisp
14 | self.feature_extraction = feature_extraction()
15 |
16 | ########
17 | self.dres0 = nn.Sequential(convbn_3d(64, 32, 3, 1, 1),
18 | nn.ReLU(inplace=True),
19 | convbn_3d(32, 32, 3, 1, 1),
20 | nn.ReLU(inplace=True))
21 |
22 | self.dres1 = nn.Sequential(convbn_3d(32, 32, 3, 1, 1),
23 | nn.ReLU(inplace=True),
24 | convbn_3d(32, 32, 3, 1, 1))
25 |
26 | self.dres2 = nn.Sequential(convbn_3d(32, 32, 3, 1, 1),
27 | nn.ReLU(inplace=True),
28 | convbn_3d(32, 32, 3, 1, 1))
29 |
30 | self.dres3 = nn.Sequential(convbn_3d(32, 32, 3, 1, 1),
31 | nn.ReLU(inplace=True),
32 | convbn_3d(32, 32, 3, 1, 1))
33 |
34 | self.dres4 = nn.Sequential(convbn_3d(32, 32, 3, 1, 1),
35 | nn.ReLU(inplace=True),
36 | convbn_3d(32, 32, 3, 1, 1))
37 |
38 | self.classify = nn.Sequential(convbn_3d(32, 32, 3, 1, 1),
39 | nn.ReLU(inplace=True),
40 | nn.Conv3d(32, 1, kernel_size=3, padding=1, stride=1,bias=False))
41 |
42 |
43 | for m in self.modules():
44 | if isinstance(m, nn.Conv2d):
45 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
46 | m.weight.data.normal_(0, math.sqrt(2. / n))
47 | elif isinstance(m, nn.Conv3d):
48 | n = m.kernel_size[0] * m.kernel_size[1]*m.kernel_size[2] * m.out_channels
49 | m.weight.data.normal_(0, math.sqrt(2. / n))
50 | elif isinstance(m, nn.BatchNorm2d):
51 | m.weight.data.fill_(1)
52 | m.bias.data.zero_()
53 | elif isinstance(m, nn.BatchNorm3d):
54 | m.weight.data.fill_(1)
55 | m.bias.data.zero_()
56 | elif isinstance(m, nn.Linear):
57 | m.bias.data.zero_()
58 |
59 |
60 | def forward(self, left, right):
61 |
62 | refimg_fea = self.feature_extraction(left)
63 | targetimg_fea = self.feature_extraction(right)
64 |
65 | #matching
66 | cost = Variable(torch.FloatTensor(refimg_fea.size()[0], refimg_fea.size()[1]*2, self.maxdisp//4, refimg_fea.size()[2], refimg_fea.size()[3]).zero_(), volatile= not self.training).cuda()
67 |
68 | for i in range(self.maxdisp//4):
69 | if i > 0 :
70 | cost[:, :refimg_fea.size()[1], i, :,i:] = refimg_fea[:,:,:,i:]
71 | cost[:, refimg_fea.size()[1]:, i, :,i:] = targetimg_fea[:,:,:,:-i]
72 | else:
73 | cost[:, :refimg_fea.size()[1], i, :,:] = refimg_fea
74 | cost[:, refimg_fea.size()[1]:, i, :,:] = targetimg_fea
75 | cost = cost.contiguous()
76 |
77 | cost0 = self.dres0(cost)
78 | cost0 = self.dres1(cost0) + cost0
79 | cost0 = self.dres2(cost0) + cost0
80 | cost0 = self.dres3(cost0) + cost0
81 | cost0 = self.dres4(cost0) + cost0
82 |
83 | cost = self.classify(cost0)
84 | cost = F.upsample(cost, [self.maxdisp,left.size()[2],left.size()[3]], mode='trilinear')
85 | cost = torch.squeeze(cost,1)
86 | pred = F.softmax(cost, dim=1)
87 | pred = disparityregression(self.maxdisp)(pred)
88 |
89 | return pred
90 |
--------------------------------------------------------------------------------
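
The matching step in forward() builds a 5-D cost volume by concatenating left features with right features shifted by each candidate disparity, at 1/4 of the input resolution. A standalone sketch of that construction with toy tensor sizes (names are illustrative only):

    import torch

    B, C, H, W, D = 1, 32, 16, 32, 12          # toy sizes; D corresponds to maxdisp // 4
    left_fea = torch.rand(B, C, H, W)
    right_fea = torch.rand(B, C, H, W)

    cost = torch.zeros(B, 2 * C, D, H, W)
    for d in range(D):
        if d > 0:
            # at disparity d, pixel x in the left view matches pixel x - d in the right view
            cost[:, :C, d, :, d:] = left_fea[:, :, :, d:]
            cost[:, C:, d, :, d:] = right_fea[:, :, :, :-d]
        else:
            cost[:, :C, d, :, :] = left_fea
            cost[:, C:, d, :, :] = right_fea
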
/psmnet/models/stackhourglass.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | from .submodule import *
4 | import torch
5 | import torch.nn as nn
6 | import torch.utils.data
7 | from torch.autograd import Variable
8 | import torch.nn.functional as F
9 | import math
10 |
11 |
12 | class hourglass(nn.Module):
13 | def __init__(self, inplanes):
14 | super(hourglass, self).__init__()
15 |
16 | self.conv1 = nn.Sequential(convbn_3d(inplanes, inplanes * 2, kernel_size=3, stride=2, pad=1),
17 | nn.ReLU(inplace=True))
18 |
19 | self.conv2 = convbn_3d(inplanes * 2, inplanes * 2, kernel_size=3, stride=1, pad=1)
20 |
21 | self.conv3 = nn.Sequential(convbn_3d(inplanes * 2, inplanes * 2, kernel_size=3, stride=2, pad=1),
22 | nn.ReLU(inplace=True))
23 |
24 | self.conv4 = nn.Sequential(convbn_3d(inplanes * 2, inplanes * 2, kernel_size=3, stride=1, pad=1),
25 | nn.ReLU(inplace=True))
26 |
27 | self.conv5 = nn.Sequential(
28 | nn.ConvTranspose3d(inplanes * 2, inplanes * 2, kernel_size=3, padding=1, output_padding=1, stride=2,
29 | bias=False),
30 | nn.BatchNorm3d(inplanes * 2)) # +conv2
31 |
32 | self.conv6 = nn.Sequential(
33 | nn.ConvTranspose3d(inplanes * 2, inplanes, kernel_size=3, padding=1, output_padding=1, stride=2,
34 | bias=False),
35 | nn.BatchNorm3d(inplanes)) # +x
36 |
37 | def forward(self, x, presqu, postsqu):
38 |
39 | out = self.conv1(x) # in:1/4 out:1/8
40 | pre = self.conv2(out) # in:1/8 out:1/8
41 | if postsqu is not None:
42 | pre = F.relu(pre + postsqu, inplace=True)
43 | else:
44 | pre = F.relu(pre, inplace=True)
45 |
46 | out = self.conv3(pre) # in:1/8 out:1/16
47 | out = self.conv4(out) # in:1/16 out:1/16
48 |
49 | if presqu is not None:
50 | post = F.relu(self.conv5(out) + presqu, inplace=True) # in:1/16 out:1/8
51 | else:
52 | post = F.relu(self.conv5(out) + pre, inplace=True)
53 |
54 | out = self.conv6(post) # in:1/8 out:1/4
55 |
56 | return out, pre, post
57 |
58 |
59 | class PSMNet(nn.Module):
60 | def __init__(self, maxdisp):
61 | super(PSMNet, self).__init__()
62 | self.maxdisp = maxdisp
63 |
64 | self.feature_extraction = feature_extraction()
65 |
66 | self.dres0 = nn.Sequential(convbn_3d(64, 32, 3, 1, 1),
67 | nn.ReLU(inplace=True),
68 | convbn_3d(32, 32, 3, 1, 1),
69 | nn.ReLU(inplace=True))
70 |
71 | self.dres1 = nn.Sequential(convbn_3d(32, 32, 3, 1, 1),
72 | nn.ReLU(inplace=True),
73 | convbn_3d(32, 32, 3, 1, 1))
74 |
75 | self.dres2 = hourglass(32)
76 |
77 | self.dres3 = hourglass(32)
78 |
79 | self.dres4 = hourglass(32)
80 |
81 | self.classif1 = nn.Sequential(convbn_3d(32, 32, 3, 1, 1),
82 | nn.ReLU(inplace=True),
83 | nn.Conv3d(32, 1, kernel_size=3, padding=1, stride=1, bias=False))
84 |
85 | self.classif2 = nn.Sequential(convbn_3d(32, 32, 3, 1, 1),
86 | nn.ReLU(inplace=True),
87 | nn.Conv3d(32, 1, kernel_size=3, padding=1, stride=1, bias=False))
88 |
89 | self.classif3 = nn.Sequential(convbn_3d(32, 32, 3, 1, 1),
90 | nn.ReLU(inplace=True),
91 | nn.Conv3d(32, 1, kernel_size=3, padding=1, stride=1, bias=False))
92 |
93 | for m in self.modules():
94 | if isinstance(m, nn.Conv2d):
95 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
96 | m.weight.data.normal_(0, math.sqrt(2. / n))
97 | elif isinstance(m, nn.Conv3d):
98 | n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels
99 | m.weight.data.normal_(0, math.sqrt(2. / n))
100 | elif isinstance(m, nn.BatchNorm2d):
101 | m.weight.data.fill_(1)
102 | m.bias.data.zero_()
103 | elif isinstance(m, nn.BatchNorm3d):
104 | m.weight.data.fill_(1)
105 | m.bias.data.zero_()
106 | elif isinstance(m, nn.Linear):
107 | m.bias.data.zero_()
108 |
109 | def forward(self, left, right):
110 |
111 | refimg_fea = self.feature_extraction(left)
112 | targetimg_fea = self.feature_extraction(right)
113 |
114 | # matching
115 | cost = Variable(
116 | torch.FloatTensor(refimg_fea.size()[0], refimg_fea.size()[1] * 2, self.maxdisp // 4, refimg_fea.size()[2],
117 | refimg_fea.size()[3]).zero_()).cuda()
118 |
119 | for i in range(self.maxdisp // 4):
120 | if i > 0:
121 | cost[:, :refimg_fea.size()[1], i, :, i:] = refimg_fea[:, :, :, i:]
122 | cost[:, refimg_fea.size()[1]:, i, :, i:] = targetimg_fea[:, :, :, :-i]
123 | else:
124 | cost[:, :refimg_fea.size()[1], i, :, :] = refimg_fea
125 | cost[:, refimg_fea.size()[1]:, i, :, :] = targetimg_fea
126 | cost = cost.contiguous()
127 |
128 | cost0 = self.dres0(cost)
129 | cost0 = self.dres1(cost0) + cost0
130 |
131 | out1, pre1, post1 = self.dres2(cost0, None, None)
132 | out1 = out1 + cost0
133 |
134 | out2, pre2, post2 = self.dres3(out1, pre1, post1)
135 | out2 = out2 + cost0
136 |
137 | out3, pre3, post3 = self.dres4(out2, pre1, post2)
138 | out3 = out3 + cost0
139 |
140 | cost1 = self.classif1(out1)
141 | cost2 = self.classif2(out2) + cost1
142 | cost3 = self.classif3(out3) + cost2
143 |
144 | if self.training:
145 | cost1 = F.upsample(cost1, [self.maxdisp, left.size()[2], left.size()[3]], mode='trilinear')
146 | cost2 = F.upsample(cost2, [self.maxdisp, left.size()[2], left.size()[3]], mode='trilinear')
147 |
148 | cost1 = torch.squeeze(cost1, 1)
149 | pred1 = F.softmax(cost1, dim=1)
150 | pred1 = disparityregression(self.maxdisp)(pred1)
151 |
152 | cost2 = torch.squeeze(cost2, 1)
153 | pred2 = F.softmax(cost2, dim=1)
154 | pred2 = disparityregression(self.maxdisp)(pred2)
155 |
156 | cost3 = F.upsample(cost3, [self.maxdisp, left.size()[2], left.size()[3]], mode='trilinear')
157 | cost3 = torch.squeeze(cost3, 1)
158 | pred3 = F.softmax(cost3, dim=1)
159 | pred3 = disparityregression(self.maxdisp)(pred3)
160 |
161 | if self.training:
162 | return pred1, pred2, pred3
163 | else:
164 | return pred3
165 |
--------------------------------------------------------------------------------
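
A minimal forward-pass sketch (not part of the repository). It assumes a CUDA device, since the cost volume and disparity regression are created with .cuda(), an input size divisible by 16 to match the hourglass downsampling, and `psmnet/` as the working directory; in eval mode only the final prediction pred3 is returned:

    import torch
    from models import stackhourglass   # models/__init__.py exposes PSMNet as stackhourglass

    model = stackhourglass(192).cuda()
    model.eval()
    left = torch.rand(1, 3, 256, 512).cuda()    # arbitrary crop size divisible by 16
    right = torch.rand(1, 3, 256, 512).cuda()
    with torch.no_grad():
        disp = model(left, right)                # 1 x 256 x 512 disparity map
    print(disp.shape)
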
/psmnet/models/submodule.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import torch
3 | import torch.nn as nn
4 | import torch.utils.data
5 | from torch.autograd import Variable
6 | import torch.nn.functional as F
7 | import math
8 | import numpy as np
9 |
10 | def convbn(in_planes, out_planes, kernel_size, stride, pad, dilation):
11 |
12 | return nn.Sequential(nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=dilation if dilation > 1 else pad, dilation = dilation, bias=False),
13 | nn.BatchNorm2d(out_planes))
14 |
15 |
16 | def convbn_3d(in_planes, out_planes, kernel_size, stride, pad):
17 |
18 | return nn.Sequential(nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, padding=pad, stride=stride,bias=False),
19 | nn.BatchNorm3d(out_planes))
20 |
21 | class BasicBlock(nn.Module):
22 | expansion = 1
23 | def __init__(self, inplanes, planes, stride, downsample, pad, dilation):
24 | super(BasicBlock, self).__init__()
25 |
26 | self.conv1 = nn.Sequential(convbn(inplanes, planes, 3, stride, pad, dilation),
27 | nn.ReLU(inplace=True))
28 |
29 | self.conv2 = convbn(planes, planes, 3, 1, pad, dilation)
30 |
31 | self.downsample = downsample
32 | self.stride = stride
33 |
34 | def forward(self, x):
35 | out = self.conv1(x)
36 | out = self.conv2(out)
37 |
38 | if self.downsample is not None:
39 | x = self.downsample(x)
40 |
41 | out += x
42 |
43 | return out
44 |
45 | class matchshifted(nn.Module):
46 | def __init__(self):
47 | super(matchshifted, self).__init__()
48 |
49 | def forward(self, left, right, shift):
50 | batch, filters, height, width = left.size()
51 | shifted_left = F.pad(torch.index_select(left, 3, Variable(torch.LongTensor([i for i in range(shift,width)])).cuda()),(shift,0,0,0))
52 | shifted_right = F.pad(torch.index_select(right, 3, Variable(torch.LongTensor([i for i in range(width-shift)])).cuda()),(shift,0,0,0))
53 | out = torch.cat((shifted_left,shifted_right),1).view(batch,filters*2,1,height,width)
54 | return out
55 |
56 | class disparityregression(nn.Module):
57 | def __init__(self, maxdisp):
58 | super(disparityregression, self).__init__()
59 | self.disp = Variable(torch.Tensor(np.reshape(np.array(range(maxdisp)),[1,maxdisp,1,1])).cuda(), requires_grad=False)
60 |
61 | def forward(self, x):
62 | disp = self.disp.repeat(x.size()[0],1,x.size()[2],x.size()[3])
63 | out = torch.sum(x*disp,1)
64 | return out
65 |
66 | class feature_extraction(nn.Module):
67 | def __init__(self):
68 | super(feature_extraction, self).__init__()
69 | self.inplanes = 32
70 | self.firstconv = nn.Sequential(convbn(3, 32, 3, 2, 1, 1),
71 | nn.ReLU(inplace=True),
72 | convbn(32, 32, 3, 1, 1, 1),
73 | nn.ReLU(inplace=True),
74 | convbn(32, 32, 3, 1, 1, 1),
75 | nn.ReLU(inplace=True))
76 |
77 | self.layer1 = self._make_layer(BasicBlock, 32, 3, 1,1,1)
78 | self.layer2 = self._make_layer(BasicBlock, 64, 16, 2,1,1)
79 | self.layer3 = self._make_layer(BasicBlock, 128, 3, 1,1,1)
80 | self.layer4 = self._make_layer(BasicBlock, 128, 3, 1,1,2)
81 |
82 | self.branch1 = nn.Sequential(nn.AvgPool2d((64, 64), stride=(64,64)),
83 | convbn(128, 32, 1, 1, 0, 1),
84 | nn.ReLU(inplace=True))
85 |
86 | self.branch2 = nn.Sequential(nn.AvgPool2d((32, 32), stride=(32,32)),
87 | convbn(128, 32, 1, 1, 0, 1),
88 | nn.ReLU(inplace=True))
89 |
90 | self.branch3 = nn.Sequential(nn.AvgPool2d((16, 16), stride=(16,16)),
91 | convbn(128, 32, 1, 1, 0, 1),
92 | nn.ReLU(inplace=True))
93 |
94 | self.branch4 = nn.Sequential(nn.AvgPool2d((8, 8), stride=(8,8)),
95 | convbn(128, 32, 1, 1, 0, 1),
96 | nn.ReLU(inplace=True))
97 |
98 | self.lastconv = nn.Sequential(convbn(320, 128, 3, 1, 1, 1),
99 | nn.ReLU(inplace=True),
100 | nn.Conv2d(128, 32, kernel_size=1, padding=0, stride = 1, bias=False))
101 |
102 | def _make_layer(self, block, planes, blocks, stride, pad, dilation):
103 | downsample = None
104 | if stride != 1 or self.inplanes != planes * block.expansion:
105 | downsample = nn.Sequential(
106 | nn.Conv2d(self.inplanes, planes * block.expansion,
107 | kernel_size=1, stride=stride, bias=False),
108 | nn.BatchNorm2d(planes * block.expansion),)
109 |
110 | layers = []
111 | layers.append(block(self.inplanes, planes, stride, downsample, pad, dilation))
112 | self.inplanes = planes * block.expansion
113 | for i in range(1, blocks):
114 | layers.append(block(self.inplanes, planes,1,None,pad,dilation))
115 |
116 | return nn.Sequential(*layers)
117 |
118 | def forward(self, x):
119 | output = self.firstconv(x)
120 | output = self.layer1(output)
121 | output_raw = self.layer2(output)
122 | output = self.layer3(output_raw)
123 | output_skip = self.layer4(output)
124 |
125 |
126 | output_branch1 = self.branch1(output_skip)
127 | output_branch1 = F.upsample(output_branch1, (output_skip.size()[2],output_skip.size()[3]),mode='bilinear')
128 |
129 | output_branch2 = self.branch2(output_skip)
130 | output_branch2 = F.upsample(output_branch2, (output_skip.size()[2],output_skip.size()[3]),mode='bilinear')
131 |
132 | output_branch3 = self.branch3(output_skip)
133 | output_branch3 = F.upsample(output_branch3, (output_skip.size()[2],output_skip.size()[3]),mode='bilinear')
134 |
135 | output_branch4 = self.branch4(output_skip)
136 | output_branch4 = F.upsample(output_branch4, (output_skip.size()[2],output_skip.size()[3]),mode='bilinear')
137 |
138 | output_feature = torch.cat((output_raw, output_skip, output_branch4, output_branch3, output_branch2, output_branch1), 1)
139 | output_feature = self.lastconv(output_feature)
140 |
141 | return output_feature
142 |
143 |
144 |
145 |
--------------------------------------------------------------------------------
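
disparityregression implements the soft-argmin readout: the per-pixel probabilities over candidate disparities are collapsed to a single value by taking the expected disparity. Numerically it is equivalent to this small sketch (toy sizes, PyTorch >= 0.4 style):

    import torch
    import torch.nn.functional as F

    maxdisp = 192
    cost = torch.rand(1, maxdisp, 8, 8)                  # toy B x D x H x W cost volume slice
    probs = F.softmax(cost, dim=1)                       # per-pixel distribution over disparities
    disp_values = torch.arange(maxdisp, dtype=torch.float32).view(1, maxdisp, 1, 1)
    expected_disp = (probs * disp_values).sum(dim=1)     # B x H x W, as disparityregression returns
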
/psmnet/submission.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import argparse
3 | import os
4 | import random
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.parallel
8 | import torch.backends.cudnn as cudnn
9 | import torch.optim as optim
10 | import torch.utils.data
11 | from torch.autograd import Variable
12 | import torch.nn.functional as F
13 | import skimage
14 | import skimage.io
15 | import skimage.transform
16 | import numpy as np
17 | import time
18 | import math
19 | from utils import preprocess
20 | from models import *
21 |
22 | # 2012 data /media/jiaren/ImageNet/data_scene_flow_2012/testing/
23 |
24 | parser = argparse.ArgumentParser(description='PSMNet')
25 | parser.add_argument('--KITTI', default='2015',
26 | help='KITTI version')
27 | parser.add_argument('--datapath', default='/scratch/datasets/kitti2015/testing/',
28 | help='path to the KITTI testing data')
29 | parser.add_argument('--loadmodel', default=None,
30 | help='loading model')
31 | parser.add_argument('--model', default='stackhourglass',
32 | help='select model')
33 | parser.add_argument('--maxdisp', type=int, default=192,
34 | help='maximum disparity')
35 | parser.add_argument('--no-cuda', action='store_true', default=False,
36 | help='disables CUDA training')
37 | parser.add_argument('--seed', type=int, default=1, metavar='S',
38 | help='random seed (default: 1)')
39 | parser.add_argument('--save_path', type=str, default='finetune_1000', metavar='S',
40 | help='path to save the predictions')
41 | parser.add_argument('--save_figure', action='store_true', help='if set, save the disparity as a 16-bit png; otherwise save the numpy file')
42 | args = parser.parse_args()
43 | args.cuda = not args.no_cuda and torch.cuda.is_available()
44 |
45 | torch.manual_seed(args.seed)
46 | if args.cuda:
47 | torch.cuda.manual_seed(args.seed)
48 |
49 | if args.KITTI == '2015':
50 | from dataloader import KITTI_submission_loader as DA
51 | else:
52 | from dataloader import KITTI_submission_loader2012 as DA
53 |
54 |
55 | test_left_img, test_right_img = DA.dataloader(args.datapath)
56 |
57 | if args.model == 'stackhourglass':
58 | model = stackhourglass(args.maxdisp)
59 | elif args.model == 'basic':
60 | model = basic(args.maxdisp)
61 | else:
62 | print('no model')
63 |
64 | model = nn.DataParallel(model, device_ids=[0])
65 | model.cuda()
66 |
67 | if args.loadmodel is not None:
68 | state_dict = torch.load(args.loadmodel)
69 | model.load_state_dict(state_dict['state_dict'])
70 |
71 | print('Number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()])))
72 |
73 | def test(imgL,imgR):
74 | model.eval()
75 |
76 | if args.cuda:
77 | imgL = torch.FloatTensor(imgL).cuda()
78 | imgR = torch.FloatTensor(imgR).cuda()
79 |
80 | imgL, imgR= Variable(imgL), Variable(imgR)
81 |
82 | with torch.no_grad():
83 | output = model(imgL,imgR)
84 | output = torch.squeeze(output)
85 | pred_disp = output.data.cpu().numpy()
86 |
87 | return pred_disp
88 |
89 |
90 | def main():
91 | processed = preprocess.get_transform(augment=False)
92 | if not os.path.isdir(args.save_path):
93 | os.makedirs(args.save_path)
94 |
95 |
96 | for inx in range(len(test_left_img)):
97 |
98 | imgL_o = (skimage.io.imread(test_left_img[inx]).astype('float32'))
99 | imgR_o = (skimage.io.imread(test_right_img[inx]).astype('float32'))
100 | imgL = processed(imgL_o).numpy()
101 | imgR = processed(imgR_o).numpy()
102 | imgL = np.reshape(imgL,[1,3,imgL.shape[1],imgL.shape[2]])
103 | imgR = np.reshape(imgR,[1,3,imgR.shape[1],imgR.shape[2]])
104 |
105 | # pad to (384, 1248)
106 | top_pad = 384-imgL.shape[2]
107 | left_pad = 1248-imgL.shape[3]
108 | imgL = np.lib.pad(imgL,((0,0),(0,0),(top_pad,0),(0,left_pad)),mode='constant',constant_values=0)
109 | imgR = np.lib.pad(imgR,((0,0),(0,0),(top_pad,0),(0,left_pad)),mode='constant',constant_values=0)
110 |
111 | start_time = time.time()
112 | pred_disp = test(imgL,imgR)
113 | print('time = %.2f' %(time.time() - start_time))
114 |
115 | top_pad = 384-imgL_o.shape[0]
116 | left_pad = 1248-imgL_o.shape[1]
117 | img = pred_disp[top_pad:,:-left_pad]
118 | print(test_left_img[inx].split('/')[-1])
119 | if args.save_figure:
120 | skimage.io.imsave(args.save_path+'/'+test_left_img[inx].split('/')[-1],(img*256).astype('uint16'))
121 | else:
122 | np.save(args.save_path+'/'+test_left_img[inx].split('/')[-1][:-4], img)
123 |
124 | if __name__ == '__main__':
125 | main()
126 |
127 |
128 |
129 |
130 |
131 |
132 |
--------------------------------------------------------------------------------
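
With --save_figure set, the script writes 16-bit PNGs scaled by 256 (the KITTI disparity convention); otherwise it writes raw .npy arrays. A small sketch for reading a prediction back (file names are hypothetical):

    import numpy as np
    import skimage.io

    disp_png = skimage.io.imread('finetune_1000/000000_10.png')   # written with --save_figure
    disp = disp_png.astype(np.float32) / 256.0                    # undo the *256 uint16 scaling

    disp_npy = np.load('finetune_1000/000000_10.npy')             # written without --save_figure
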
/psmnet/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mileyan/pseudo_lidar/032c7a0d73c3fdf84e934af3f57f8eb489a52906/psmnet/utils/__init__.py
--------------------------------------------------------------------------------
/psmnet/utils/preprocess.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision.transforms as transforms
3 | import random
4 |
5 | __imagenet_stats = {'mean': [0.485, 0.456, 0.406],
6 | 'std': [0.229, 0.224, 0.225]}
7 |
8 | #__imagenet_stats = {'mean': [0.5, 0.5, 0.5],
9 | # 'std': [0.5, 0.5, 0.5]}
10 |
11 | __imagenet_pca = {
12 | 'eigval': torch.Tensor([0.2175, 0.0188, 0.0045]),
13 | 'eigvec': torch.Tensor([
14 | [-0.5675, 0.7192, 0.4009],
15 | [-0.5808, -0.0045, -0.8140],
16 | [-0.5836, -0.6948, 0.4203],
17 | ])
18 | }
19 |
20 |
21 | def scale_crop(input_size, scale_size=None, normalize=__imagenet_stats):
22 | t_list = [
23 | transforms.ToTensor(),
24 | transforms.Normalize(**normalize),
25 | ]
26 | #if scale_size != input_size:
27 | #t_list = [transforms.Scale((960,540))] + t_list
28 |
29 | return transforms.Compose(t_list)
30 |
31 |
32 | def scale_random_crop(input_size, scale_size=None, normalize=__imagenet_stats):
33 | t_list = [
34 | transforms.RandomCrop(input_size),
35 | transforms.ToTensor(),
36 | transforms.Normalize(**normalize),
37 | ]
38 | if scale_size != input_size:
39 | t_list = [transforms.Scale(scale_size)] + t_list
40 |
41 | return transforms.Compose(t_list)
42 |
43 |
44 | def pad_random_crop(input_size, scale_size=None, normalize=__imagenet_stats):
45 | padding = int((scale_size - input_size) / 2)
46 | return transforms.Compose([
47 | transforms.RandomCrop(input_size, padding=padding),
48 | transforms.RandomHorizontalFlip(),
49 | transforms.ToTensor(),
50 | transforms.Normalize(**normalize),
51 | ])
52 |
53 |
54 | def inception_preproccess(input_size, normalize=__imagenet_stats):
55 | return transforms.Compose([
56 | transforms.RandomSizedCrop(input_size),
57 | transforms.RandomHorizontalFlip(),
58 | transforms.ToTensor(),
59 | transforms.Normalize(**normalize)
60 | ])
61 | def inception_color_preproccess(input_size, normalize=__imagenet_stats):
62 | return transforms.Compose([
63 | #transforms.RandomSizedCrop(input_size),
64 | #transforms.RandomHorizontalFlip(),
65 | transforms.ToTensor(),
66 | ColorJitter(
67 | brightness=0.4,
68 | contrast=0.4,
69 | saturation=0.4,
70 | ),
71 | Lighting(0.1, __imagenet_pca['eigval'], __imagenet_pca['eigvec']),
72 | transforms.Normalize(**normalize)
73 | ])
74 |
75 |
76 | def get_transform(name='imagenet', input_size=None,
77 | scale_size=None, normalize=None, augment=True):
78 | normalize = __imagenet_stats
79 | input_size = 256
80 | if augment:
81 | return inception_color_preproccess(input_size, normalize=normalize)
82 | else:
83 | return scale_crop(input_size=input_size,
84 | scale_size=scale_size, normalize=normalize)
85 |
86 |
87 |
88 |
89 | class Lighting(object):
90 | """Lighting noise(AlexNet - style PCA - based noise)"""
91 |
92 | def __init__(self, alphastd, eigval, eigvec):
93 | self.alphastd = alphastd
94 | self.eigval = eigval
95 | self.eigvec = eigvec
96 |
97 | def __call__(self, img):
98 | if self.alphastd == 0:
99 | return img
100 |
101 | alpha = img.new().resize_(3).normal_(0, self.alphastd)
102 | rgb = self.eigvec.type_as(img).clone()\
103 | .mul(alpha.view(1, 3).expand(3, 3))\
104 | .mul(self.eigval.view(1, 3).expand(3, 3))\
105 | .sum(1).squeeze()
106 |
107 | return img.add(rgb.view(3, 1, 1).expand_as(img))
108 |
109 |
110 | class Grayscale(object):
111 |
112 | def __call__(self, img):
113 | gs = img.clone()
114 | gs[0].mul_(0.299).add_(0.587, gs[1]).add_(0.114, gs[2])
115 | gs[1].copy_(gs[0])
116 | gs[2].copy_(gs[0])
117 | return gs
118 |
119 |
120 | class Saturation(object):
121 |
122 | def __init__(self, var):
123 | self.var = var
124 |
125 | def __call__(self, img):
126 | gs = Grayscale()(img)
127 | alpha = random.uniform(0, self.var)
128 | return img.lerp(gs, alpha)
129 |
130 |
131 | class Brightness(object):
132 |
133 | def __init__(self, var):
134 | self.var = var
135 |
136 | def __call__(self, img):
137 | gs = img.new().resize_as_(img).zero_()
138 | alpha = random.uniform(0, self.var)
139 | return img.lerp(gs, alpha)
140 |
141 |
142 | class Contrast(object):
143 |
144 | def __init__(self, var):
145 | self.var = var
146 |
147 | def __call__(self, img):
148 | gs = Grayscale()(img)
149 | gs.fill_(gs.mean())
150 | alpha = random.uniform(0, self.var)
151 | return img.lerp(gs, alpha)
152 |
153 |
154 | class RandomOrder(object):
155 | """ Composes several transforms together in random order.
156 | """
157 |
158 | def __init__(self, transforms):
159 | self.transforms = transforms
160 |
161 | def __call__(self, img):
162 | if self.transforms is None:
163 | return img
164 | order = torch.randperm(len(self.transforms))
165 | for i in order:
166 | img = self.transforms[i](img)
167 | return img
168 |
169 |
170 | class ColorJitter(RandomOrder):
171 |
172 | def __init__(self, brightness=0.4, contrast=0.4, saturation=0.4):
173 | self.transforms = []
174 | if brightness != 0:
175 | self.transforms.append(Brightness(brightness))
176 | if contrast != 0:
177 | self.transforms.append(Contrast(contrast))
178 | if saturation != 0:
179 | self.transforms.append(Saturation(saturation))
180 |
--------------------------------------------------------------------------------
/psmnet/utils/readpfm.py:
--------------------------------------------------------------------------------
1 | import re
2 | import numpy as np
3 | import sys
4 |
5 |
6 | def readPFM(file):
7 | file = open(file, 'rb')
8 |
9 | color = None
10 | width = None
11 | height = None
12 | scale = None
13 | endian = None
14 |
15 | header = file.readline().rstrip()
16 | if header == 'PF':
17 | color = True
18 | elif header == 'Pf':
19 | color = False
20 | else:
21 | raise Exception('Not a PFM file.')
22 |
23 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline())
24 | if dim_match:
25 | width, height = map(int, dim_match.groups())
26 | else:
27 | raise Exception('Malformed PFM header.')
28 |
29 | scale = float(file.readline().rstrip())
30 | if scale < 0: # little-endian
31 | endian = '<'
32 | scale = -scale
33 | else:
34 | endian = '>' # big-endian
35 |
36 | data = np.fromfile(file, endian + 'f')
37 | shape = (height, width, 3) if color else (height, width)
38 |
39 | data = np.reshape(data, shape)
40 | data = np.flipud(data)
41 | return data, scale
42 |
43 |
--------------------------------------------------------------------------------
/visualization/000012.bin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mileyan/pseudo_lidar/032c7a0d73c3fdf84e934af3f57f8eb489a52906/visualization/000012.bin
--------------------------------------------------------------------------------
/visualization/pyntcloud.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mileyan/pseudo_lidar/032c7a0d73c3fdf84e934af3f57f8eb489a52906/visualization/pyntcloud.png
--------------------------------------------------------------------------------