├── LICENSE
├── README.md
├── co
│   ├── __init__.py
│   ├── args.py
│   ├── cmap.py
│   ├── geometry.py
│   ├── gtimer.py
│   ├── io3d.py
│   ├── metric.py
│   └── utils.py
├── config.json
├── data
│   ├── __init__.py
│   ├── base_dataset.py
│   ├── create_syn_data.py
│   ├── data_manipulation.py
│   ├── dataset.py
│   ├── default_pattern.png
│   ├── kinect_pattern.png
│   ├── presave_disp.py
│   ├── presave_optical_flow_data.py
│   └── real_pattern.png
├── model
│   ├── __init__.py
│   ├── ext_functions.py
│   ├── multi_frame_networks.py
│   ├── multi_frame_worker.py
│   ├── networks.py
│   ├── single_frame_worker.py
│   └── worker.py
├── requirements.txt
└── train_val.py
/LICENSE:
--------------------------------------------------------------------------------
1 | LICENSE FOR ORIGINAL FILE / https://github.com/autonomousvision/connecting_the_dots/blob/master/LICENSE
2 |
3 | MIT License
4 |
5 | Copyright (c) 2019 autonomousvision
6 |
7 | Permission is hereby granted, free of charge, to any person obtaining a copy
8 | of this software and associated documentation files (the "Software"), to deal
9 | in the Software without restriction, including without limitation the rights
10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | copies of the Software, and to permit persons to whom the Software is
12 | furnished to do so, subject to the following conditions:
13 |
14 | The above copyright notice and this permission notice shall be included in all
15 | copies or substantial portions of the Software.
16 |
17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | SOFTWARE.
24 |
25 |
26 |
27 | LICENSE FOR MODIFICATIONS OF ORIGINAL FILE AND FOR OTHER FILES IN THIS REPOSITORY
28 |
29 | MIT LICENSE
30 |
31 | Copyright 2021, ams International AG
32 |
33 | Permission is hereby granted, free of charge, to any person obtaining a copy
34 | of this software and associated documentation files (the "Software"), to deal
35 | in the Software without restriction, including without limitation the rights
36 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
37 | copies of the Software, and to permit persons to whom the Software is
38 | furnished to do so, subject to the following conditions:
39 |
40 | The above copyright notice and this permission notice shall be included in all
41 | copies or substantial portions of the Software.
42 |
43 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
44 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
45 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
46 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
47 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
48 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
49 | SOFTWARE.
50 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | > # [ICCV 2021] DepthInSpace: Exploitation and Fusion of Multiple Video Frames for Structured-Light Depth Estimation
3 | > Mohammad Mahdi Johari, Camilla Carta, François Fleuret
4 | > [Project Page](https://www.idiap.ch/paper/depthinspace/) | [Paper](https://openaccess.thecvf.com/content/ICCV2021/html/Johari_DepthInSpace_Exploitation_and_Fusion_of_Multiple_Video_Frames_for_Structured-Light_ICCV_2021_paper.html)
5 |
6 | ## Dependencies
7 |
8 | The network training/evaluation code is based on `PyTorch` and is tested in the following environment:
9 | ```
10 | Python==3.8.6
11 | PyTorch==1.7.0
12 | CUDA==10.1
13 | ```
14 |
15 | All required packages can be installed with `anaconda`:
16 | ```
17 | conda install --file requirements.txt -c pytorch -c conda-forge
18 | ```
19 |
20 | ### External Libraries
21 | To train and evaluate our method on synthetic datasets, we use the structured light renderer provided by [Connecting the Dots](https://github.com/autonomousvision/connecting_the_dots).
22 | It can be used to render a virtual scene (arbitrary triangle mesh) with the structured light pattern projected from a customizable projector location.
23 | Furthermore, our models use custom layers provided by [Connecting the Dots](https://github.com/autonomousvision/connecting_the_dots).
24 | First, download [ShapeNet V2](https://www.shapenet.org/) and set `SHAPENET_DIR` in `config.json` to its location.
25 | Then, install these dependencies with the following instructions and set `CTD_DIR` in `config.json` to the path of the cloned [Connecting the Dots](https://github.com/autonomousvision/connecting_the_dots) repository:
26 |
27 | ```
28 | git clone https://github.com/autonomousvision/connecting_the_dots.git
29 | cd connecting_the_dots
30 | cd renderer
31 | make
32 | cd ..
33 | cd data/lcn
34 | python setup.py build_ext --inplace
35 | cd ../..
36 | cd torchext
37 | python setup.py build_ext --inplace
38 | cd ..
39 | ```
40 |
41 | As a preprocessing step, you need to run [LiteFlowNet](https://github.com/sniklaus/pytorch-liteflownet) on the data before running our models. To this end, clone [our forked copy of pytorch-liteflownet](https://github.com/MohammadJohari/pytorch-liteflownet) with the following command and set `LITEFLOWNET_DIR` in `config.json` accordingly.
42 | ```
43 | git clone https://github.com/MohammadJohari/pytorch-liteflownet.git
44 | ```
45 | Make sure you comply with the [license](https://github.com/twhui/LiteFlowNet#license-and-citation) terms of the original LiteFlowNet paper.
46 | That said, the DepthInSpace models are compatible with any external optical flow library.
47 | To use DepthInSpace with another optical flow library, modify the code in `data/presave_optical_flow_data.py` to make it compatible with your optical flow model of choice.
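
As a rough illustration of the kind of change this involves, the following is a minimal sketch and not part of the repository: it pre-computes flow between consecutive frames of a sequence with a generic flow model and stores the result as a NumPy array. The `estimate_flow` helper, the file layout, and the array shapes are assumptions and must be adapted to whatever format `data/presave_optical_flow_data.py` actually expects.
```python
# Hypothetical sketch of pre-saving optical flow with a generic flow model.
# estimate_flow(), paths, and array shapes are illustrative assumptions.
import numpy as np

def estimate_flow(img0, img1):
    # Placeholder for any optical flow model (LiteFlowNet, RAFT, ...) that
    # returns a (2, H, W) flow field mapping pixels of img0 to img1.
    raise NotImplementedError

def presave_flow_for_sequence(frames, out_path):
    # frames: list of per-frame numpy arrays belonging to one video sequence
    flows = [estimate_flow(frames[i], frames[i + 1]) for i in range(len(frames) - 1)]
    np.save(out_path, np.stack(flows))  # shape (N-1, 2, H, W)
```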
48 |
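Since all paths used by the code are read from `config.json`, it can be worth sanity-checking them once before running anything. The snippet below is a hypothetical helper, not part of the repository; it only checks the directory keys referenced in this README (`DATA_DIR` and `OUTPUT_DIR` are introduced in the sections below), and the actual `config.json` may contain additional entries.
```python
# Illustrative check of the config.json keys referenced in this README.
import json, os

with open('config.json') as f:
    cfg = json.load(f)

for key in ['SHAPENET_DIR', 'CTD_DIR', 'LITEFLOWNET_DIR', 'DATA_DIR', 'OUTPUT_DIR']:
    path = cfg.get(key)
    status = 'ok' if path and os.path.isdir(path) else 'missing'
    print(f'{key:16s} {status}: {path}')
```
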
49 | ## Running
50 |
51 |
52 | ### Creating Synthetic Data
53 | The synthetic data will be generated and saved to the `DATA_DIR` directory specified in `config.json`.
54 | To generate the data with the `default` projection dot pattern, change directory to `data` and run
55 |
56 | ```
57 | python create_syn_data.py default
58 | ```
59 |
60 | Other available options for the dot pattern are `kinect` and `real`, where `real` is the dot pattern actually observed in our experiments.
61 |
62 | ### Pre-Saving Optical Flow Data
63 | Before training, optical flow predictions from LiteFlowNet should be pre-saved. To do so, make sure `DATA_DIR` in `config.json` is correct and run the following command in the `data` directory
64 |
65 | ```
66 | python presave_optical_flow_data.py
67 | ```
68 |
69 | ### Training DIS-SF
70 | Note that the weights and training state of our networks are saved to `OUTPUT_DIR` in `config.json`.
71 | To train the DIS-SF network with an arbitrary batch size (e.g. 8), run
72 |
73 | ```
74 | python train_val.py --train_batch_size 8 --architecture single_frame
75 | ```
76 |
77 | ### Training DIS-MF
78 | Before training the DIS-MF network, the DIS-SF network must have been trained and its outputs must have been pre-saved.
79 | Make sure `DATA_DIR` in `config.json` is correct and the trained weights of the DIS-SF network are available in `OUTPUT_DIR` in `config.json`.
80 | Then, you can pre-save the outputs of a specific epoch (e.g. 100) of the DIS-SF network by running the following command in the `data` directory
81 |
82 | ```
83 | python presave_disp.py single_frame --epoch 100
84 | ```
85 |
86 | You can then train the DIS-MF network with an arbitrary batch size (e.g. 4) by running
87 |
88 | ```
89 | python train_val.py --train_batch_size 4 --architecture multi_frame
90 | ```
91 | The DIS-MF network can be trained with a batch size of 4 on a device with 24 GB of GPU memory.
92 |
93 | ### Training DIS-FTSF
94 | Before training the DIS-FTSF network, the DIS-MF network must have been trained and its outputs must have been pre-saved.
95 | Make sure `DATA_DIR` in `config.json` is correct and the trained weights of the DIS-MF network are available in `OUTPUT_DIR` in `config.json`.
96 | Then, you can pre-save the outputs of a specific epoch (e.g. 50) of the DIS-MF network by running the following command in the `data` directory
97 |
98 | ```
99 | python presave_disp.py multi_frame --epoch 50
100 | ```
101 |
102 | You can then train the DIS-FTSF network with an arbitrary batch size (e.g. 8) by running
103 |
104 | ```
105 | python train_val.py --train_batch_size 8 --architecture single_frame --use_pseudo_gt True
106 | ```
107 |
108 | ### Evaluating the Networks
109 | To evaluate a specific checkpoint of a network, e.g. the 50th epoch of the DIS-MF network, run
110 | ```
111 | python train_val.py --architecture multi_frame --cmd retest --epoch 50
112 | ```
113 | ### Training/Evaluating on Real Dataset
114 | Our captured real dataset can be downloaded from [here](https://www.idiap.ch/en/dataset/depthinspace/index_html). Make sure to pass `--data_type real` when training or evaluating the models on our real dataset so that the same train/test split as in our paper is used.
115 |
116 | ### Pretrained Networks
117 | The pretrained networks for the synthetic datasets with different projection patterns and for the real dataset can be found [here](https://drive.google.com/drive/folders/1uiSElbiQhXxag2VIpXy4lOoueoAEh1ak?usp=sharing). To use these networks, make sure `OUTPUT_DIR` in `config.json` points to the corresponding pretrained directory. Then, use the following for the single-frame model:
118 | ```
119 | python train_val.py --architecture single_frame --cmd retest --epoch 0
120 | ```
121 | and the following for the multi-frame model:
122 | ```
123 | python train_val.py --architecture multi_frame --cmd retest --epoch 0
124 | ```
125 | Make sure to add `--data_type real` when using the models on our real dataset so that the same train/test split as in our paper is used.
126 |
127 | ### Contact
128 | You can contact the author through email: mohammad.johari At idiap.ch.
129 |
130 | ### License
131 | All code in this repository is licensed under the [MIT License](LICENSE).
132 |
133 | ## Citing
134 | If you find our work useful, please consider citing:
135 | ```BibTeX
136 | @inproceedings{johari-et-al-2021,
137 | author = {Johari, M. and Carta, C. and Fleuret, F.},
138 | title = {DepthInSpace: Exploitation and Fusion of Multiple Video Frames for Structured-Light Depth Estimation},
139 | booktitle = {Proceedings of the IEEE International Conference on Computer Vision (ICCV)},
140 | year = {2021}
141 | }
142 | ```
143 |
144 | ### Acknowledgement
145 | This work was supported by ams OSRAM.
--------------------------------------------------------------------------------
/co/__init__.py:
--------------------------------------------------------------------------------
1 | # DepthInSpace is a PyTorch-based program which estimates 3D depth maps
2 | # from active structured-light sensor's multiple video frames.
3 | #
4 | # MIT License
5 | #
6 | # Copyright (c) 2019 autonomousvision
7 | #
8 | # Permission is hereby granted, free of charge, to any person obtaining a copy
9 | # of this software and associated documentation files (the "Software"), to deal
10 | # in the Software without restriction, including without limitation the rights
11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | # copies of the Software, and to permit persons to whom the Software is
13 | # furnished to do so, subject to the following conditions:
14 | #
15 | # The above copyright notice and this permission notice shall be included in all
16 | # copies or substantial portions of the Software.
17 | #
18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24 | # SOFTWARE.
25 |
26 |
27 | # set matplotlib backend depending on env
28 | import os
29 | import matplotlib
30 | if os.name == 'posix' and "DISPLAY" not in os.environ:
31 | matplotlib.use('Agg')
32 |
33 | from . import geometry
34 | from . import metric
35 | from . import utils
36 | from . import io3d
37 | from . import gtimer
38 | from . import cmap
39 | from . import args
40 |
--------------------------------------------------------------------------------
/co/args.py:
--------------------------------------------------------------------------------
1 | # DepthInSpace is a PyTorch-based program which estimates 3D depth maps
2 | # from active structured-light sensor's multiple video frames.
3 | #
4 | # MIT License
5 | #
6 | # Copyright (c) 2019 autonomousvision
7 | #
8 | # Permission is hereby granted, free of charge, to any person obtaining a copy
9 | # of this software and associated documentation files (the "Software"), to deal
10 | # in the Software without restriction, including without limitation the rights
11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | # copies of the Software, and to permit persons to whom the Software is
13 | # furnished to do so, subject to the following conditions:
14 | #
15 | # The above copyright notice and this permission notice shall be included in all
16 | # copies or substantial portions of the Software.
17 | #
18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24 | # SOFTWARE.
25 |
26 | import argparse
27 | from .utils import str2bool
28 |
29 |
30 | def parse_args():
31 | parser = argparse.ArgumentParser()
32 | #
33 | parser.add_argument('--data_type',
34 | default='synthetic', choices=['synthetic', 'real'], type=str)
35 | #
36 | parser.add_argument('--cmd',
37 | help='Start training or test',
38 | default='resume', choices=['retrain', 'resume', 'retest', 'test_init'], type=str)
39 | parser.add_argument('--epoch',
40 | help='If larger than -1, retest on the specified epoch',
41 | default=-1, type=int)
42 | parser.add_argument('--epochs',
43 | help='Training epochs',
44 | default=100, type=int)
45 | parser.add_argument('--warmup_epochs',
46 | help='Number of epochs where SGM Disparities are used as supervisor when training on the real dataset',
47 | default=150, type=int)
48 | #
49 | parser.add_argument('--lcn_radius',
50 | help='Radius of the window for LCN pre-processing',
51 | default=5, type=int)
52 | parser.add_argument('--max_disp',
53 | help='Maximum disparity',
54 | default=128, type=int)
55 | #
56 | parser.add_argument('--track_length',
57 | help='Track length for geometric loss',
58 | default=4, type=int)
59 | #
60 | parser.add_argument('--train_batch_size',
61 | help='Train Batch Size',
62 | default=8, type=int)
63 | #
64 | parser.add_argument('--architecture',
65 | help='The architecture which will be used',
66 | default='single_frame', choices=['single_frame', 'multi_frame'], type=str)
67 | #
68 | parser.add_argument('--use_pseudo_gt',
69 | help='Only applicable in single-frame model',
70 | default=False, type=str2bool)
71 |
72 | args = parser.parse_args()
73 |
74 | return args
--------------------------------------------------------------------------------
/co/cmap.py:
--------------------------------------------------------------------------------
1 | # DepthInSpace is a PyTorch-based program which estimates 3D depth maps
2 | # from active structured-light sensor's multiple video frames.
3 | #
4 | # MIT License
5 | #
6 | # Copyright (c) 2019 autonomousvision
7 | #
8 | # Permission is hereby granted, free of charge, to any person obtaining a copy
9 | # of this software and associated documentation files (the "Software"), to deal
10 | # in the Software without restriction, including without limitation the rights
11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | # copies of the Software, and to permit persons to whom the Software is
13 | # furnished to do so, subject to the following conditions:
14 | #
15 | # The above copyright notice and this permission notice shall be included in all
16 | # copies or substantial portions of the Software.
17 | #
18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24 | # SOFTWARE.
25 |
26 | import numpy as np
27 |
28 | _color_map_errors = np.array([
29 | [149, 54, 49], # 0: log2(x) = -infinity
30 | [180, 117, 69], # 0.0625: log2(x) = -4
31 | [209, 173, 116], # 0.125: log2(x) = -3
32 | [233, 217, 171], # 0.25: log2(x) = -2
33 | [248, 243, 224], # 0.5: log2(x) = -1
34 | [144, 224, 254], # 1.0: log2(x) = 0
35 | [97, 174, 253], # 2.0: log2(x) = 1
36 | [67, 109, 244], # 4.0: log2(x) = 2
37 | [39, 48, 215], # 8.0: log2(x) = 3
38 | [38, 0, 165], # 16.0: log2(x) = 4
39 | [38, 0, 165] # inf: log2(x) = inf
40 | ]).astype(float)
41 |
42 |
43 | def color_error_image(errors, scale=1.2, log_scale=0.25, mask=None, BGR=True):
44 | """
45 | Color an input error map.
46 |
47 | Arguments:
48 | errors -- HxW numpy array of errors
49 | [scale=1.2] -- scaling the error map (color change at unit error)
50 | [mask=None] -- zero-pixels are masked white in the result
51 | [BGR=True] -- toggle between BGR and RGB
52 |
53 | Returns:
54 | colored_errors -- HxWx3 numpy array visualizing the errors
55 | """
56 |
57 | errors_flat = errors.flatten()
58 | errors_color_indices = np.clip(np.log2(errors_flat / scale + 1e-5) / log_scale + 5, 0, 9)
59 | i0 = np.floor(errors_color_indices).astype(int)
60 | f1 = errors_color_indices - i0.astype(float)
61 | colored_errors_flat = _color_map_errors[i0, :] * (1 - f1).reshape(-1, 1) + _color_map_errors[i0 + 1,
62 | :] * f1.reshape(-1, 1)
63 |
64 | if mask is not None:
65 | colored_errors_flat[mask.flatten() == 0] = 255
66 |
67 | if not BGR:
68 | colored_errors_flat = colored_errors_flat[:, [2, 1, 0]]
69 |
70 | return colored_errors_flat.reshape(errors.shape[0], errors.shape[1], 3).astype(int)
71 |
72 |
73 | _color_map_depths = np.array([
74 | [0, 0, 0], # 0.000
75 | [0, 0, 255], # 0.114
76 | [255, 0, 0], # 0.299
77 | [255, 0, 255], # 0.413
78 | [0, 255, 0], # 0.587
79 | [0, 255, 255], # 0.701
80 | [255, 255, 0], # 0.886
81 | [255, 255, 255], # 1.000
82 | [255, 255, 255], # 1.000
83 | ]).astype(float)
84 | _color_map_bincenters = np.array([
85 | 0.0,
86 | 0.114,
87 | 0.299,
88 | 0.413,
89 | 0.587,
90 | 0.701,
91 | 0.886,
92 | 1.000,
93 | 2.000, # doesn't make a difference, just strictly higher than 1
94 | ])
95 |
96 |
97 | def color_depth_map(depths, scale=None):
98 | """
99 | Color an input depth map.
100 |
101 | Arguments:
102 | depths -- HxW numpy array of depths
103 | [scale=None] -- scaling the values (defaults to the maximum depth)
104 |
105 | Returns:
106 | colored_depths -- HxWx3 numpy array visualizing the depths
107 | """
108 |
109 | if scale is None:
110 | scale = depths.max()
111 |
112 | values = np.clip(depths.flatten() / scale, 0, 1)
113 | # for each value, figure out where they fit in in the bincenters: what is the last bincenter smaller than this value?
114 | lower_bin = ((values.reshape(-1, 1) >= _color_map_bincenters.reshape(1, -1)) * np.arange(0, 9)).max(axis=1)
115 | lower_bin_value = _color_map_bincenters[lower_bin]
116 | higher_bin_value = _color_map_bincenters[lower_bin + 1]
117 | alphas = (values - lower_bin_value) / (higher_bin_value - lower_bin_value)
118 | colors = _color_map_depths[lower_bin] * (1 - alphas).reshape(-1, 1) + _color_map_depths[
119 | lower_bin + 1] * alphas.reshape(-1, 1)
120 | return colors.reshape(depths.shape[0], depths.shape[1], 3).astype(np.uint8)
121 |
122 | # from utils.debug import save_color_numpy
123 | # save_color_numpy(color_depth_map(np.matmul(np.ones((100,1)), np.arange(0,1200).reshape(1,1200)), scale=1000))
124 |
--------------------------------------------------------------------------------
/co/geometry.py:
--------------------------------------------------------------------------------
1 | # DepthInSpace is a PyTorch-based program which estimates 3D depth maps
2 | # from active structured-light sensor's multiple video frames.
3 | #
4 | # MIT License
5 | #
6 | # Copyright (c) 2019 autonomousvision
7 | #
8 | # Permission is hereby granted, free of charge, to any person obtaining a copy
9 | # of this software and associated documentation files (the "Software"), to deal
10 | # in the Software without restriction, including without limitation the rights
11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | # copies of the Software, and to permit persons to whom the Software is
13 | # furnished to do so, subject to the following conditions:
14 | #
15 | # The above copyright notice and this permission notice shall be included in all
16 | # copies or substantial portions of the Software.
17 | #
18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24 | # SOFTWARE.
25 |
26 | import numpy as np
27 |
28 | def nullspace(A, atol=1e-13, rtol=0):
29 | u, s, vh = np.linalg.svd(A)
30 | tol = max(atol, rtol * s[0])
31 | nnz = (s >= tol).sum()
32 | ns = vh[nnz:].conj().T
33 | return ns
34 |
35 | def nearest_orthogonal_matrix(R):
36 | U,S,Vt = np.linalg.svd(R)
37 | return U @ np.eye(3,dtype=R.dtype) @ Vt
38 |
39 | def power_iters(A, n_iters=10):
40 | b = np.random.uniform(-1,1, size=(A.shape[0], A.shape[1], 1))
41 | for iter in range(n_iters):
42 | b = A @ b
43 | b = b / np.linalg.norm(b, axis=1, keepdims=True)
44 | return b
45 |
46 | def rayleigh_quotient(A, b):
47 | return (b.transpose(0,2,1) @ A @ b) / (b.transpose(0,2,1) @ b)
48 |
49 |
50 | def cross_prod_mat(x):
51 | x = x.reshape(-1,3)
52 | X = np.empty((x.shape[0],3,3), dtype=x.dtype)
53 | X[:,0,0] = 0
54 | X[:,0,1] = -x[:,2]
55 | X[:,0,2] = x[:,1]
56 | X[:,1,0] = x[:,2]
57 | X[:,1,1] = 0
58 | X[:,1,2] = -x[:,0]
59 | X[:,2,0] = -x[:,1]
60 | X[:,2,1] = x[:,0]
61 | X[:,2,2] = 0
62 | return X.squeeze()
63 |
64 | def hat_operator(x):
65 | return cross_prod_mat(x)
66 |
67 | def vee_operator(X):
68 | X = X.reshape(-1,3,3)
69 | x = np.empty((X.shape[0], 3), dtype=X.dtype)
70 | x[:,0] = X[:,2,1]
71 | x[:,1] = X[:,0,2]
72 | x[:,2] = X[:,1,0]
73 | return x.squeeze()
74 |
75 |
76 | def rot_x(x, dtype=np.float32):
77 | x = np.array(x, copy=False)
78 | x = x.reshape(-1,1)
79 | R = np.zeros((x.shape[0],3,3), dtype=dtype)
80 | R[:,0,0] = 1
81 | R[:,1,1] = np.cos(x).ravel()
82 | R[:,1,2] = -np.sin(x).ravel()
83 | R[:,2,1] = np.sin(x).ravel()
84 | R[:,2,2] = np.cos(x).ravel()
85 | return R.squeeze()
86 |
87 | def rot_y(y, dtype=np.float32):
88 | y = np.array(y, copy=False)
89 | y = y.reshape(-1,1)
90 | R = np.zeros((y.shape[0],3,3), dtype=dtype)
91 | R[:,0,0] = np.cos(y).ravel()
92 | R[:,0,2] = np.sin(y).ravel()
93 | R[:,1,1] = 1
94 | R[:,2,0] = -np.sin(y).ravel()
95 | R[:,2,2] = np.cos(y).ravel()
96 | return R.squeeze()
97 |
98 | def rot_z(z, dtype=np.float32):
99 | z = np.array(z, copy=False)
100 | z = z.reshape(-1,1)
101 | R = np.zeros((z.shape[0],3,3), dtype=dtype)
102 | R[:,0,0] = np.cos(z).ravel()
103 | R[:,0,1] = -np.sin(z).ravel()
104 | R[:,1,0] = np.sin(z).ravel()
105 | R[:,1,1] = np.cos(z).ravel()
106 | R[:,2,2] = 1
107 | return R.squeeze()
108 |
109 | def xyz_from_rotm(R):
110 | R = R.reshape(-1,3,3)
111 | xyz = np.empty((R.shape[0],3), dtype=R.dtype)
112 | for bidx in range(R.shape[0]):
113 | if R[bidx,0,2] < 1:
114 | if R[bidx,0,2] > -1:
115 | xyz[bidx,1] = np.arcsin(R[bidx,0,2])
116 | xyz[bidx,0] = np.arctan2(-R[bidx,1,2], R[bidx,2,2])
117 | xyz[bidx,2] = np.arctan2(-R[bidx,0,1], R[bidx,0,0])
118 | else:
119 | xyz[bidx,1] = -np.pi/2
120 | xyz[bidx,0] = -np.arctan2(R[bidx,1,0],R[bidx,1,1])
121 | xyz[bidx,2] = 0
122 | else:
123 | xyz[bidx,1] = np.pi/2
124 | xyz[bidx,0] = np.arctan2(R[bidx,1,0], R[bidx,1,1])
125 | xyz[bidx,2] = 0
126 | return xyz.squeeze()
127 |
128 | def zyx_from_rotm(R):
129 | R = R.reshape(-1,3,3)
130 | zyx = np.empty((R.shape[0],3), dtype=R.dtype)
131 | for bidx in range(R.shape[0]):
132 | if R[bidx,2,0] < 1:
133 | if R[bidx,2,0] > -1:
134 | zyx[bidx,1] = np.arcsin(-R[bidx,2,0])
135 | zyx[bidx,0] = np.arctan2(R[bidx,1,0], R[bidx,0,0])
136 | zyx[bidx,2] = np.arctan2(R[bidx,2,1], R[bidx,2,2])
137 | else:
138 | zyx[bidx,1] = np.pi / 2
139 | zyx[bidx,0] = -np.arctan2(-R[bidx,1,2], R[bidx,1,1])
140 | zyx[bidx,2] = 0
141 | else:
142 | zyx[bidx,1] = -np.pi / 2
143 | zyx[bidx,0] = np.arctan2(-R[bidx,1,2], R[bidx,1,1])
144 | zyx[bidx,2] = 0
145 | return zyx.squeeze()
146 |
147 | def rotm_from_xyz(xyz):
148 | xyz = np.array(xyz, copy=False).reshape(-1,3)
149 | return (rot_x(xyz[:,0]) @ rot_y(xyz[:,1]) @ rot_z(xyz[:,2])).squeeze()
150 |
151 | def rotm_from_zyx(zyx):
152 | zyx = np.array(zyx, copy=False).reshape(-1,3)
153 | return (rot_z(zyx[:,0]) @ rot_y(zyx[:,1]) @ rot_x(zyx[:,2])).squeeze()
154 |
155 | def rotm_from_quat(q):
156 | q = q.reshape(-1,4)
157 | w, x, y, z = q[:,0], q[:,1], q[:,2], q[:,3]
158 | R = np.array([
159 | [1 - 2*y*y - 2*z*z, 2*x*y - 2*z*w, 2*x*z + 2*y*w],
160 | [2*x*y + 2*z*w, 1 - 2*x*x - 2*z*z, 2*y*z - 2*x*w],
161 | [2*x*z - 2*y*w, 2*y*z + 2*x*w, 1 - 2*x*x - 2*y*y]
162 | ], dtype=q.dtype)
163 | R = R.transpose((2,0,1))
164 | return R.squeeze()
165 |
166 | def rotm_from_axisangle(a):
167 | # exponential
168 | a = a.reshape(-1,3)
169 | phi = np.linalg.norm(a, axis=1).reshape(-1,1,1)
170 | iphi = np.zeros_like(phi)
171 | np.divide(1, phi, out=iphi, where=phi != 0)
172 | A = cross_prod_mat(a) * iphi
173 | R = np.eye(3, dtype=a.dtype) + np.sin(phi) * A + (1 - np.cos(phi)) * A @ A
174 | return R.squeeze()
175 |
176 | def rotm_from_lookat(dir, up=None):
177 | dir = dir.reshape(-1,3)
178 | if up is None:
179 | up = np.zeros_like(dir)
180 | up[:,1] = 1
181 | dir /= np.linalg.norm(dir, axis=1, keepdims=True)
182 | up /= np.linalg.norm(up, axis=1, keepdims=True)
183 | x = dir[:,None,:] @ cross_prod_mat(up).transpose(0,2,1)
184 | y = x @ cross_prod_mat(dir).transpose(0,2,1)
185 | x = x.squeeze()
186 | y = y.squeeze()
187 | x /= np.linalg.norm(x, axis=1, keepdims=True)
188 | y /= np.linalg.norm(y, axis=1, keepdims=True)
189 | R = np.empty((dir.shape[0],3,3), dtype=dir.dtype)
190 | R[:,0,0] = x[:,0]
191 | R[:,0,1] = y[:,0]
192 | R[:,0,2] = dir[:,0]
193 | R[:,1,0] = x[:,1]
194 | R[:,1,1] = y[:,1]
195 | R[:,1,2] = dir[:,1]
196 | R[:,2,0] = x[:,2]
197 | R[:,2,1] = y[:,2]
198 | R[:,2,2] = dir[:,2]
199 | return R.transpose(0,2,1).squeeze()
200 |
201 | def rotm_distance_identity(R0, R1):
202 | # https://link.springer.com/article/10.1007%2Fs10851-009-0161-2
203 | # in [0, 2*sqrt(2)]
204 | R0 = R0.reshape(-1,3,3)
205 | R1 = R1.reshape(-1,3,3)
206 | dists = np.linalg.norm(np.eye(3,dtype=R0.dtype) - R0 @ R1.transpose(0,2,1), axis=(1,2))
207 | return dists.squeeze()
208 |
209 | def rotm_distance_geodesic(R0, R1):
210 | # https://link.springer.com/article/10.1007%2Fs10851-009-0161-2
211 | # in [0, pi)
212 | R0 = R0.reshape(-1,3,3)
213 | R1 = R1.reshape(-1,3,3)
214 | RtR = R0 @ R1.transpose(0,2,1)
215 | aa = axisangle_from_rotm(RtR)
216 | S = cross_prod_mat(aa).reshape(-1,3,3)
217 | dists = np.linalg.norm(S, axis=(1,2))
218 | return dists.squeeze()
219 |
220 |
221 |
222 | def axisangle_from_rotm(R):
223 | # logarithm of rotation matrix
224 | # R = R.reshape(-1,3,3)
225 | # tr = np.trace(R, axis1=1, axis2=2)
226 | # phi = np.arccos(np.clip((tr - 1) / 2, -1, 1))
227 | # scale = np.zeros_like(phi)
228 | # div = 2 * np.sin(phi)
229 | # np.divide(phi, div, out=scale, where=np.abs(div) > 1e-6)
230 | # A = (R - R.transpose(0,2,1)) * scale.reshape(-1,1,1)
231 | # aa = np.stack((A[:,2,1], A[:,0,2], A[:,1,0]), axis=1)
232 | # return aa.squeeze()
233 | R = R.reshape(-1,3,3)
234 | omega = np.empty((R.shape[0], 3), dtype=R.dtype)
235 | omega[:,0] = R[:,2,1] - R[:,1,2]
236 | omega[:,1] = R[:,0,2] - R[:,2,0]
237 | omega[:,2] = R[:,1,0] - R[:,0,1]
238 | r = np.linalg.norm(omega, axis=1).reshape(-1,1)
239 | t = np.trace(R, axis1=1, axis2=2).reshape(-1,1)
240 | omega = np.arctan2(r, t-1) * omega
241 | aa = np.zeros_like(omega)
242 | np.divide(omega, r, out=aa, where=r != 0)
243 | return aa.squeeze()
244 |
245 | def axisangle_from_quat(q):
246 | q = q.reshape(-1,4)
247 | phi = 2 * np.arccos(q[:,0])
248 | denom = np.zeros_like(q[:,0])
249 | np.divide(1, np.sqrt(1 - q[:,0]**2), out=denom, where=q[:,0] != 1)
250 | axis = q[:,1:] * denom.reshape(-1,1)
251 | denom = np.linalg.norm(axis, axis=1).reshape(-1,1)
252 | a = np.zeros_like(axis)
253 | np.divide(phi.reshape(-1,1) * axis, denom, out=a, where=denom != 0)
254 | aa = a.astype(q.dtype)
255 | return aa.squeeze()
256 |
257 | def axisangle_apply(aa, x):
258 | # working only with single aa and single x at the moment
259 | xshape = x.shape
260 | aa = aa.reshape(3,)
261 | x = x.reshape(3,)
262 | phi = np.linalg.norm(aa)
263 | e = np.zeros_like(aa)
264 | np.divide(aa, phi, out=e, where=phi != 0)
265 | xr = np.cos(phi) * x + np.sin(phi) * np.cross(e, x) + (1 - np.cos(phi)) * (e.T @ x) * e
266 | return xr.reshape(xshape)
267 |
268 |
269 | def exp_so3(R):
270 | w = axisangle_from_rotm(R)
271 | return w
272 |
273 | def log_so3(w):
274 | R = rotm_from_axisangle(w)
275 | return R
276 |
277 | def exp_se3(R, t):
278 | R = R.reshape(-1,3,3)
279 | t = t.reshape(-1,3)
280 |
281 | w = exp_so3(R).reshape(-1,3)
282 |
283 | phi = np.linalg.norm(w, axis=1).reshape(-1,1,1)
284 | A = cross_prod_mat(w)
285 | Vi = np.eye(3, dtype=R.dtype) - A/2 + (1 - (phi * np.sin(phi) / (2 * (1 - np.cos(phi))))) / phi**2 * A @ A
286 | u = t.reshape(-1,1,3) @ Vi.transpose(0,2,1)
287 |
288 | # v = (u, w)
289 | v = np.empty((R.shape[0],6), dtype=R.dtype)
290 | v[:,:3] = u.squeeze()
291 | v[:,3:] = w
292 |
293 | return v.squeeze()
294 |
295 | def log_se3(v):
296 | # v = (u, w)
297 | v = v.reshape(-1,6)
298 | u = v[:,:3]
299 | w = v[:,3:]
300 |
301 | R = log_so3(w)
302 |
303 | phi = np.linalg.norm(w, axis=1).reshape(-1,1,1)
304 | A = cross_prod_mat(w)
305 | V = np.eye(3, dtype=v.dtype) + (1 - np.cos(phi)) / phi**2 * A + (phi - np.sin(phi)) / phi**3 * A @ A
306 | t = u.reshape(-1,1,3) @ V.transpose(0,2,1)
307 |
308 | return R.squeeze(), t.squeeze()
309 |
310 |
311 | def quat_from_rotm(R):
312 | R = R.reshape(-1,3,3)
313 | q = np.empty((R.shape[0], 4,), dtype=R.dtype)
314 | q[:,0] = np.sqrt( np.maximum(0, 1 + R[:,0,0] + R[:,1,1] + R[:,2,2]) )
315 | q[:,1] = np.sqrt( np.maximum(0, 1 + R[:,0,0] - R[:,1,1] - R[:,2,2]) )
316 | q[:,2] = np.sqrt( np.maximum(0, 1 - R[:,0,0] + R[:,1,1] - R[:,2,2]) )
317 | q[:,3] = np.sqrt( np.maximum(0, 1 - R[:,0,0] - R[:,1,1] + R[:,2,2]) )
318 | q[:,1] *= np.sign(q[:,1] * (R[:,2,1] - R[:,1,2]))
319 | q[:,2] *= np.sign(q[:,2] * (R[:,0,2] - R[:,2,0]))
320 | q[:,3] *= np.sign(q[:,3] * (R[:,1,0] - R[:,0,1]))
321 | q /= np.linalg.norm(q,axis=1,keepdims=True)
322 | return q.squeeze()
323 |
324 | def quat_from_axisangle(a):
325 | a = a.reshape(-1, 3)
326 | phi = np.linalg.norm(a, axis=1)
327 | iphi = np.zeros_like(phi)
328 | np.divide(1, phi, out=iphi, where=phi != 0)
329 | a = a * iphi.reshape(-1,1)
330 | theta = phi / 2.0
331 | r = np.cos(theta)
332 | stheta = np.sin(theta)
333 | q = np.stack((r, stheta*a[:,0], stheta*a[:,1], stheta*a[:,2]), axis=1)
334 | q /= np.linalg.norm(q, axis=1).reshape(-1,1)
335 | return q.squeeze()
336 |
337 | def quat_identity(n=1, dtype=np.float32):
338 | q = np.zeros((n,4), dtype=dtype)
339 | q[:,0] = 1
340 | return q.squeeze()
341 |
342 | def quat_conjugate(q):
343 | shape = q.shape
344 | q = q.reshape(-1,4).copy()
345 | q[:,1:] *= -1
346 | return q.reshape(shape)
347 |
348 | def quat_product(q1, q2):
349 | # q1 . q2 is equivalent to R(q1) @ R(q2)
350 | shape = q1.shape
351 | q1, q2 = q1.reshape(-1,4), q2.reshape(-1, 4)
352 | q = np.empty((max(q1.shape[0], q2.shape[0]), 4), dtype=q1.dtype)
353 | a1,b1,c1,d1 = q1[:,0], q1[:,1], q1[:,2], q1[:,3]
354 | a2,b2,c2,d2 = q2[:,0], q2[:,1], q2[:,2], q2[:,3]
355 | q[:,0] = a1 * a2 - b1 * b2 - c1 * c2 - d1 * d2
356 | q[:,1] = a1 * b2 + b1 * a2 + c1 * d2 - d1 * c2
357 | q[:,2] = a1 * c2 - b1 * d2 + c1 * a2 + d1 * b2
358 | q[:,3] = a1 * d2 + b1 * c2 - c1 * b2 + d1 * a2
359 | return q.squeeze()
360 |
361 | def quat_apply(q, x):
362 | xshape = x.shape
363 | x = x.reshape(-1, 3)
364 | qshape = q.shape
365 | q = q.reshape(-1, 4)
366 |
367 | p = np.empty((x.shape[0], 4), dtype=x.dtype)
368 | p[:,0] = 0
369 | p[:,1:] = x
370 |
371 | r = quat_product(quat_product(q, p), quat_conjugate(q))
372 | if r.ndim == 1:
373 | return r[1:].reshape(xshape)
374 | else:
375 | return r[:,1:].reshape(xshape)
376 |
377 |
378 | def quat_random(rng=None, n=1):
379 | # http://planning.cs.uiuc.edu/node198.html
380 | if rng is not None:
381 | u = rng.uniform(0, 1, size=(3,n))
382 | else:
383 | u = np.random.uniform(0, 1, size=(3,n))
384 | q = np.array((
385 | np.sqrt(1 - u[0]) * np.sin(2 * np.pi * u[1]),
386 | np.sqrt(1 - u[0]) * np.cos(2 * np.pi * u[1]),
387 | np.sqrt(u[0]) * np.sin(2 * np.pi * u[2]),
388 | np.sqrt(u[0]) * np.cos(2 * np.pi * u[2])
389 | )).T
390 | q /= np.linalg.norm(q,axis=1,keepdims=True)
391 | return q.squeeze()
392 |
393 | def quat_distance_angle(q0, q1):
394 | # https://math.stackexchange.com/questions/90081/quaternion-distance
395 | # https://link.springer.com/article/10.1007%2Fs10851-009-0161-2
396 | q0 = q0.reshape(-1,4)
397 | q1 = q1.reshape(-1,4)
398 | dists = np.arccos(np.clip(2 * np.sum(q0 * q1, axis=1)**2 - 1, -1, 1))
399 | return dists
400 |
401 | def quat_distance_normdiff(q0, q1):
402 | # https://link.springer.com/article/10.1007%2Fs10851-009-0161-2
403 | # \phi_4
404 | # [0, 1]
405 | q0 = q0.reshape(-1,4)
406 | q1 = q1.reshape(-1,4)
407 | return 1 - np.sum(q0 * q1, axis=1)**2
408 |
409 | def quat_distance_mineucl(q0, q1):
410 | # https://link.springer.com/article/10.1007%2Fs10851-009-0161-2
411 | # http://users.cecs.anu.edu.au/~trumpf/pubs/Hartley_Trumpf_Dai_Li.pdf
412 | q0 = q0.reshape(-1,4)
413 | q1 = q1.reshape(-1,4)
414 | diff0 = ((q0 - q1)**2).sum(axis=1)
415 | diff1 = ((q0 + q1)**2).sum(axis=1)
416 | return np.minimum(diff0, diff1)
417 |
418 | def quat_slerp_space(q0, q1, num=100, endpoint=True):
419 | q0 = q0.ravel()
420 | q1 = q1.ravel()
421 | dot = q0.dot(q1)
422 | if dot < 0:
423 | q1 *= -1
424 | dot *= -1
425 | t = np.linspace(0, 1, num=num, endpoint=endpoint, dtype=q0.dtype)
426 | t = t.reshape((-1,1))
427 | if dot > 0.9995:
428 | ret = q0 + t * (q1 - q0)
429 | return ret
430 | dot = np.clip(dot, -1, 1)
431 | theta0 = np.arccos(dot)
432 | theta = theta0 * t
433 | s0 = np.cos(theta) - dot * np.sin(theta) / np.sin(theta0)
434 | s1 = np.sin(theta) / np.sin(theta0)
435 | return (s0 * q0) + (s1 * q1)
436 |
437 | def cart_to_spherical(x):
438 | shape = x.shape
439 | x = x.reshape(-1,3)
440 | y = np.empty_like(x)
441 | y[:,0] = np.linalg.norm(x, axis=1) # r
442 | y[:,1] = np.arccos(x[:,2] / y[:,0]) # theta
443 | y[:,2] = np.arctan2(x[:,1], x[:,0]) # phi
444 | return y.reshape(shape)
445 |
446 | def spherical_to_cart(x):
447 | shape = x.shape
448 | x = x.reshape(-1,3)
449 | y = np.empty_like(x)
450 | y[:,0] = x[:,0] * np.sin(x[:,1]) * np.cos(x[:,2])
451 | y[:,1] = x[:,0] * np.sin(x[:,1]) * np.sin(x[:,2])
452 | y[:,2] = x[:,0] * np.cos(x[:,1])
453 | return y.reshape(shape)
454 |
455 | def spherical_random(r=1, n=1):
456 | # http://mathworld.wolfram.com/SpherePointPicking.html
457 | # https://math.stackexchange.com/questions/1585975/how-to-generate-random-points-on-a-sphere
458 | x = np.empty((n,3))
459 | x[:,0] = r
460 | x[:,1] = 2 * np.pi * np.random.uniform(0,1, size=(n,))
461 | x[:,2] = np.arccos(2 * np.random.uniform(0,1, size=(n,)) - 1)
462 | return x.squeeze()
463 |
464 | def color_pcl(pcl, K, im, color_axis=0, as_int=True, invalid_color=[0,0,0]):
465 | uvd = K @ pcl.T
466 | uvd /= uvd[2]
467 | uvd = np.round(uvd).astype(np.int32)
468 | mask = np.logical_and(uvd[0] >= 0, uvd[1] >= 0)
469 | color = np.empty((pcl.shape[0], 3), dtype=im.dtype)
470 | if color_axis == 0:
471 | mask = np.logical_and(mask, uvd[0] < im.shape[2])
472 | mask = np.logical_and(mask, uvd[1] < im.shape[1])
473 | uvd = uvd[:,mask]
474 | color[mask,:] = im[:,uvd[1],uvd[0]].T
475 | elif color_axis == 2:
476 | mask = np.logical_and(mask, uvd[0] < im.shape[1])
477 | mask = np.logical_and(mask, uvd[1] < im.shape[0])
478 | uvd = uvd[:,mask]
479 | color[mask,:] = im[uvd[1],uvd[0], :]
480 | else:
481 | raise Exception('invalid color_axis')
482 | color[np.logical_not(mask),:3] = invalid_color
483 | if as_int:
484 | color = (255.0 * color).astype(np.int32)
485 | return color
486 |
487 | def center_pcl(pcl, robust=False, copy=False, axis=1):
488 | if copy:
489 | pcl = pcl.copy()
490 | if robust:
491 | mu = np.median(pcl, axis=axis, keepdims=True)
492 | else:
493 | mu = np.mean(pcl, axis=axis, keepdims=True)
494 | return pcl - mu
495 |
496 | def to_homogeneous(x):
497 | # return np.hstack((x, np.ones((x.shape[0],1),dtype=x.dtype)))
498 | return np.concatenate((x, np.ones((*x.shape[:-1],1),dtype=x.dtype)), axis=-1)
499 |
500 | def from_homogeneous(x):
501 | return x[:,:-1] / x[:,-1:]
502 |
503 | def project_uvn(uv, Ki=None):
504 | if uv.shape[1] == 2:
505 | uvn = to_homogeneous(uv)
506 | else:
507 | uvn = uv
508 | if uvn.shape[1] != 3:
509 | raise Exception('uv should have shape Nx2 or Nx3')
510 | if Ki is None:
511 | return uvn
512 | else:
513 | return uvn @ Ki.T
514 |
515 | def project_uvd(uv, depth, K=np.eye(3), R=np.eye(3), t=np.zeros((3,1)), ignore_negative_depth=True, return_uvn=False):
516 | Ki = np.linalg.inv(K)
517 |
518 | if ignore_negative_depth:
519 | mask = depth >= 0
520 | uv = uv[mask,:]
521 | d = depth[mask]
522 | else:
523 | d = depth.ravel()
524 |
525 | uv1 = to_homogeneous(uv)
526 |
527 | uvn1 = uv1 @ Ki.T
528 | xyz = d.reshape(-1,1) * uvn1
529 | xyz = (xyz - t.reshape((1,3))) @ R
530 |
531 | if return_uvn:
532 | return xyz, uvn1
533 | else:
534 | return xyz
535 |
536 | def project_depth(depth, K, R=np.eye(3,3), t=np.zeros((3,1)), ignore_negative_depth=True, return_uvn=False):
537 | u, v = np.meshgrid(range(depth.shape[1]), range(depth.shape[0]))
538 | uv = np.hstack((u.reshape(-1,1), v.reshape(-1,1)))
539 | return project_uvd(uv, depth.ravel(), K, R, t, ignore_negative_depth, return_uvn)
540 |
541 |
542 | def project_xyz(xyz, K=np.eye(3), R=np.eye(3,3), t=np.zeros((3,1))):
543 | uvd = K @ (R @ xyz.T + t.reshape((3,1)))
544 | uvd[:2] /= uvd[2]
545 | return uvd[:2].T, uvd[2]
546 |
547 |
548 | def relative_motion(R0, t0, R1, t1, Rt_from_global=True):
549 | t0 = t0.reshape((3,1))
550 | t1 = t1.reshape((3,1))
551 | if Rt_from_global:
552 | Rr = R1 @ R0.T
553 | tr = t1 - Rr @ t0
554 | else:
555 | Rr = R1.T @ R0
556 | tr = R1.T @ (t0 - t1)
557 | return Rr, tr.ravel()
558 |
559 |
560 | def translation_to_cameracenter(R, t):
561 | t = t.reshape(-1,3,1)
562 | R = R.reshape(-1,3,3)
563 | C = -R.transpose(0,2,1) @ t
564 | return C.squeeze()
565 |
566 | def cameracenter_to_translation(R, C):
567 | C = C.reshape(-1,3,1)
568 | R = R.reshape(-1,3,3)
569 | t = -R @ C
570 | return t.squeeze()
571 |
572 | def decompose_projection_matrix(P, return_t=True):
573 | if P.shape[0] != 3 or P.shape[1] != 4:
574 | raise Exception('P has to be 3x4')
575 | M = P[:, :3]
576 | C = -np.linalg.inv(M) @ P[:, 3:]
577 |
578 | R,K = np.linalg.qr(np.flipud(M).T)
579 | K = np.flipud(K.T)
580 | K = np.fliplr(K)
581 | R = np.flipud(R.T)
582 |
583 | T = np.diag(np.sign(np.diag(K)))
584 | K = K @ T
585 | R = T @ R
586 |
587 | if np.linalg.det(R) < 0:
588 | R *= -1
589 |
590 | K /= K[2,2]
591 | if return_t:
592 | return K, R, cameracenter_to_translation(R, C)
593 | else:
594 | return K, R, C
595 |
596 |
597 | def compose_projection_matrix(K=np.eye(3), R=np.eye(3,3), t=np.zeros((3,1))):
598 | return K @ np.hstack((R, t.reshape((3,1))))
599 |
600 |
601 |
602 | def point_plane_distance(pts, plane):
603 | pts = pts.reshape(-1,3)
604 | return np.abs(np.sum(plane[:3] * pts, axis=1) + plane[3]) / np.linalg.norm(plane[:3])
605 |
606 | def fit_plane(pts):
607 | pts = pts.reshape(-1,3)
608 | center = np.mean(pts, axis=0)
609 | A = pts - center
610 | u, s, vh = np.linalg.svd(A, full_matrices=False)
611 | plane = np.array([*vh[2], -vh[2].dot(center)])
612 | return plane
613 |
614 | def tetrahedron(dtype=np.float32):
615 | verts = np.array([
616 | (np.sqrt(8/9), 0, -1/3), (-np.sqrt(2/9), np.sqrt(2/3), -1/3),
617 | (-np.sqrt(2/9), -np.sqrt(2/3), -1/3), (0, 0, 1)], dtype=dtype)
618 | faces = np.array([(0,1,2), (0,2,3), (0,1,3), (1,2,3)], dtype=np.int32)
619 | normals = -np.mean(verts, axis=0) + verts
620 | normals /= np.linalg.norm(normals, axis=1).reshape(-1,1)
621 | return verts, faces, normals
622 |
623 | def cube(dtype=np.float32):
624 | verts = np.array([
625 | [-0.5,-0.5,-0.5], [-0.5,0.5,-0.5], [0.5,0.5,-0.5], [0.5,-0.5,-0.5],
626 | [-0.5,-0.5,0.5], [-0.5,0.5,0.5], [0.5,0.5,0.5], [0.5,-0.5,0.5]], dtype=dtype)
627 | faces = np.array([
628 | (0,1,2), (0,2,3), (4,5,6), (4,6,7),
629 | (0,4,7), (0,7,3), (1,5,6), (1,6,2),
630 | (3,2,6), (3,6,7), (0,1,5), (0,5,4)], dtype=np.int32)
631 | normals = -np.mean(verts, axis=0) + verts
632 | normals /= np.linalg.norm(normals, axis=1).reshape(-1,1)
633 | return verts, faces, normals
634 |
635 | def octahedron(dtype=np.float32):
636 | verts = np.array([
637 | (+1,0,0), (0,+1,0), (0,0,+1),
638 | (-1,0,0), (0,-1,0), (0,0,-1)], dtype=dtype)
639 | faces = np.array([
640 | (0,1,2), (1,2,3), (3,2,4), (4,2,0),
641 | (0,1,5), (1,5,3), (3,5,4), (4,5,0)], dtype=np.int32)
642 | normals = -np.mean(verts, axis=0) + verts
643 | normals /= np.linalg.norm(normals, axis=1).reshape(-1,1)
644 | return verts, faces, normals
645 |
646 | def icosahedron(dtype=np.float32):
647 | p = (1 + np.sqrt(5)) / 2
648 | verts = np.array([
649 | (-1,0,p), (1,0,p), (1,0,-p), (-1,0,-p),
650 | (0,-p,1), (0,p,1), (0,p,-1), (0,-p,-1),
651 | (-p,-1,0), (p,-1,0), (p,1,0), (-p,1,0)
652 | ], dtype=dtype)
653 | faces = np.array([
654 | (0,1,4), (0,1,5), (1,4,9), (1,9,10), (1,10,5), (0,4,8), (0,8,11), (0,11,5),
655 | (5,6,11), (5,6,10), (4,7,8), (4,7,9),
656 | (3,2,6), (3,2,7), (2,6,10), (2,10,9), (2,9,7), (3,6,11), (3,11,8), (3,8,7),
657 | ], dtype=np.int32)
658 | normals = -np.mean(verts, axis=0) + verts
659 | normals /= np.linalg.norm(normals, axis=1).reshape(-1,1)
660 | return verts, faces, normals
661 |
662 | def xyplane(dtype=np.float32, z=0, interleaved=False):
663 | if interleaved:
664 | eps = 1e-6
665 | verts = np.array([
666 | (-1,-1,z), (-1,1,z), (1,1,z),
667 | (1-eps,1,z), (1-eps,-1,z), (-1-eps,-1,z)], dtype=dtype)
668 | faces = np.array([(0,1,2), (3,4,5)], dtype=np.int32)
669 | else:
670 | verts = np.array([(-1,-1,z), (-1,1,z), (1,1,z), (1,-1,z)], dtype=dtype)
671 | faces = np.array([(0,1,2), (0,2,3)], dtype=np.int32)
672 | normals = np.zeros_like(verts)
673 | normals[:,2] = -1
674 | return verts, faces, normals
675 |
676 | def mesh_independent_verts(verts, faces, normals=None):
677 | new_verts = []
678 | new_normals = []
679 | for f in faces:
680 | new_verts.append(verts[f[0]])
681 | new_verts.append(verts[f[1]])
682 | new_verts.append(verts[f[2]])
683 | if normals is not None:
684 | new_normals.append(normals[f[0]])
685 | new_normals.append(normals[f[1]])
686 | new_normals.append(normals[f[2]])
687 | new_verts = np.array(new_verts)
688 | new_faces = np.arange(0, faces.size, dtype=faces.dtype).reshape(-1,3)
689 | if normals is None:
690 | return new_verts, new_faces
691 | else:
692 | new_normals = np.array(new_normals)
693 | return new_verts, new_faces, new_normals
694 |
695 |
696 | def stack_mesh(verts, faces):
697 | n_verts = 0
698 | mfaces = []
699 | for idx, f in enumerate(faces):
700 | mfaces.append(f + n_verts)
701 | n_verts += verts[idx].shape[0]
702 | verts = np.vstack(verts)
703 | faces = np.vstack(mfaces)
704 | return verts, faces
705 |
706 | def normalize_mesh(verts):
707 | # all the verts have unit distance to the center (0,0,0)
708 | return verts / np.linalg.norm(verts, axis=1, keepdims=True)
709 |
710 |
711 | def mesh_triangle_areas(verts, faces):
712 | a = verts[faces[:,0]]
713 | b = verts[faces[:,1]]
714 | c = verts[faces[:,2]]
715 | x = np.empty_like(a)
716 | x = a - b
717 | y = a - c
718 | t = np.empty_like(a)
719 | t[:,0] = (x[:,1] * y[:,2] - x[:,2] * y[:,1]);
720 | t[:,1] = (x[:,2] * y[:,0] - x[:,0] * y[:,2]);
721 | t[:,2] = (x[:,0] * y[:,1] - x[:,1] * y[:,0]);
722 | return np.linalg.norm(t, axis=1) / 2
723 |
724 | def subdivde_mesh(verts_in, faces_in, n=1):
725 | for iter in range(n):
726 | verts = []
727 | for v in verts_in:
728 | verts.append(v)
729 | faces = []
730 | verts_dict = {}
731 | for f in faces_in:
732 | f = np.sort(f)
733 | i0,i1,i2 = f
734 | v0,v1,v2 = verts_in[f]
735 |
736 | k = i0*len(verts_in)+i1
737 | if k in verts_dict:
738 | i01 = verts_dict[k]
739 | else:
740 | i01 = len(verts)
741 | verts_dict[k] = i01
742 | v01 = (v0 + v1) / 2
743 | verts.append(v01)
744 |
745 | k = i0*len(verts_in)+i2
746 | if k in verts_dict:
747 | i02 = verts_dict[k]
748 | else:
749 | i02 = len(verts)
750 | verts_dict[k] = i02
751 | v02 = (v0 + v2) / 2
752 | verts.append(v02)
753 |
754 | k = i1*len(verts_in)+i2
755 | if k in verts_dict:
756 | i12 = verts_dict[k]
757 | else:
758 | i12 = len(verts)
759 | verts_dict[k] = i12
760 | v12 = (v1 + v2) / 2
761 | verts.append(v12)
762 |
763 | faces.append((i0,i01,i02))
764 | faces.append((i01,i1,i12))
765 | faces.append((i12,i2,i02))
766 | faces.append((i01,i12,i02))
767 |
768 | verts_in = np.array(verts, dtype=verts_in.dtype)
769 | faces_in = np.array(faces, dtype=np.int32)
770 | return verts_in, faces_in
771 |
772 |
773 | def mesh_adjust_winding_order(verts, faces, normals):
774 | n0 = normals[faces[:,0]]
775 | n1 = normals[faces[:,1]]
776 | n2 = normals[faces[:,2]]
777 | fnormals = (n0 + n1 + n2) / 3
778 |
779 | v0 = verts[faces[:,0]]
780 | v1 = verts[faces[:,1]]
781 | v2 = verts[faces[:,2]]
782 |
783 | e0 = v1 - v0
784 | e1 = v2 - v0
785 | fn = np.cross(e0, e1)
786 |
787 | dot = np.sum(fnormals * fn, axis=1)
788 | ma = dot < 0
789 |
790 | nfaces = faces.copy()
791 | nfaces[ma,1], nfaces[ma,2] = nfaces[ma,2], nfaces[ma,1]
792 |
793 | return nfaces
794 |
795 |
796 | def pcl_to_shapecl(verts, colors=None, shape='cube', width=1.0):
797 | if shape == 'tetrahedron':
798 | cverts, cfaces, _ = tetrahedron()
799 | elif shape == 'cube':
800 | cverts, cfaces, _ = cube()
801 | elif shape == 'octahedron':
802 | cverts, cfaces, _ = octahedron()
803 | elif shape == 'icosahedron':
804 | cverts, cfaces, _ = icosahedron()
805 | else:
806 | raise Exception('invalid shape')
807 |
808 | sverts = np.tile(cverts, (verts.shape[0], 1))
809 | sverts *= width
810 | sverts += np.repeat(verts, cverts.shape[0], axis=0)
811 |
812 | sfaces = np.tile(cfaces, (verts.shape[0], 1))
813 | sfoffset = cverts.shape[0] * np.arange(0, verts.shape[0])
814 | sfaces += np.repeat(sfoffset, cfaces.shape[0]).reshape(-1,1)
815 |
816 | if colors is not None:
817 | scolors = np.repeat(colors, cverts.shape[0], axis=0)
818 | else:
819 | scolors = None
820 |
821 | return sverts, sfaces, scolors
822 |
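823 | # Example usage (illustrative):
824 | #   plane = fit_plane(pts)                    # pts: Nx3 array -> plane = [a, b, c, d]
825 | #   dists = point_plane_distance(pts, plane)  # distance of each point to that plane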
--------------------------------------------------------------------------------
/co/gtimer.py:
--------------------------------------------------------------------------------
1 | # DepthInSpace is a PyTorch-based program which estimates 3D depth maps
2 | # from active structured-light sensor's multiple video frames.
3 | #
4 | # MIT License
5 | #
6 | # Copyright (c) 2019 autonomousvision
7 | #
8 | # Permission is hereby granted, free of charge, to any person obtaining a copy
9 | # of this software and associated documentation files (the "Software"), to deal
10 | # in the Software without restriction, including without limitation the rights
11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | # copies of the Software, and to permit persons to whom the Software is
13 | # furnished to do so, subject to the following conditions:
14 | #
15 | # The above copyright notice and this permission notice shall be included in all
16 | # copies or substantial portions of the Software.
17 | #
18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24 | # SOFTWARE.
25 |
26 | import numpy as np
27 |
28 | from . import utils
29 |
30 | class StopWatch(utils.StopWatch):
31 | def __del__(self):
32 | print('='*80)
33 | print('gtimer:')
34 | total = ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get(reduce=np.sum).items()])
35 | print(f' [total] {total}')
36 | mean = ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get(reduce=np.mean).items()])
37 | print(f' [mean] {mean}')
38 | median = ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get(reduce=np.median).items()])
39 | print(f' [median] {median}')
40 | print('='*80)
41 |
42 | GTIMER = StopWatch()
43 |
44 | def start(name):
45 | GTIMER.start(name)
46 | def stop(name):
47 | GTIMER.stop(name)
48 |
49 | class Ctx(object):
50 | def __init__(self, name):
51 | self.name = name
52 |
53 | def __enter__(self):
54 | start(self.name)
55 |
56 | def __exit__(self, *args):
57 | stop(self.name)
58 |
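59 | # Example usage (illustrative): time a named code block via the global timer
60 | #   with Ctx('step'):
61 | #       do_work()
62 | # Accumulated timings are printed when the interpreter exits (StopWatch.__del__).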
--------------------------------------------------------------------------------
/co/io3d.py:
--------------------------------------------------------------------------------
1 | # DepthInSpace is a PyTorch-based program which estimates 3D depth maps
2 | # from active structured-light sensor's multiple video frames.
3 | #
4 | # MIT License
5 | #
6 | # Copyright (c) 2019 autonomousvision
7 | #
8 | # Permission is hereby granted, free of charge, to any person obtaining a copy
9 | # of this software and associated documentation files (the "Software"), to deal
10 | # in the Software without restriction, including without limitation the rights
11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | # copies of the Software, and to permit persons to whom the Software is
13 | # furnished to do so, subject to the following conditions:
14 | #
15 | # The above copyright notice and this permission notice shall be included in all
16 | # copies or substantial portions of the Software.
17 | #
18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24 | # SOFTWARE.
25 |
26 | import struct
27 | import numpy as np
28 | import collections
29 |
30 | def _write_ply_point(fp, x,y,z, color=None, normal=None, binary=False):
31 | args = [x,y,z]
32 | if color is not None:
33 | args += [int(color[0]), int(color[1]), int(color[2])]
34 | if normal is not None:
35 | args += [normal[0],normal[1],normal[2]]
36 | if binary:
37 | fmt = ' 1:
101 | c = color[vidx]
102 | else:
103 | c = color[0]
104 | else:
105 | c = None
106 | if normals is None:
107 | n = None
108 | else:
109 | n = normals[vidx]
110 | _write_ply_point(fp, v[0],v[1],v[2], c, n, binary)
111 |
112 | if trias is not None:
113 | for t in trias:
114 | _write_ply_triangle(fp, t[0],t[1],t[2], binary)
115 |
116 | def faces_to_triangles(faces):
117 | new_faces = []
118 | for f in faces:
119 | if f[0] == 3:
120 | new_faces.append([f[1], f[2], f[3]])
121 | elif f[0] == 4:
122 | new_faces.append([f[1], f[2], f[3]])
123 | new_faces.append([f[3], f[4], f[1]])
124 | else:
125 | raise Exception('unknown face count %d' % f[0])
126 | return new_faces
127 |
128 | def read_ply(path):
129 | with open(path, 'rb') as f:
130 | # parse header
131 | line = f.readline().decode().strip()
132 | if line != 'ply':
133 | raise Exception('Header error')
134 | n_verts = 0
135 | n_faces = 0
136 | vert_types = {}
137 | vert_bin_format = []
138 | vert_bin_len = 0
139 | vert_bin_cols = 0
140 | line = f.readline().decode()
141 | parse_vertex_prop = False
142 | while line.strip() != 'end_header':
143 | if 'format' in line:
144 | if 'ascii' in line:
145 | binary = False
146 | elif 'binary_little_endian' in line:
147 | binary = True
148 | else:
149 | raise Exception('invalid ply format')
150 | if 'element face' in line:
151 | splits = line.strip().split(' ')
152 | n_faces = int(splits[-1])
153 | parse_vertex_prop = False
154 | if 'element camera' in line:
155 | parse_vertex_prop = False
156 | if 'element vertex' in line:
157 | splits = line.strip().split(' ')
158 | n_verts = int(splits[-1])
159 | parse_vertex_prop = True
160 | if parse_vertex_prop and 'property' in line:
161 | prop = line.strip().split()
162 | if prop[1] == 'float':
163 | vert_bin_format.append('f4')
164 | vert_bin_len += 4
165 | vert_bin_cols += 1
166 | elif prop[1] == 'uchar':
167 | vert_bin_format.append('B')
168 | vert_bin_len += 1
169 | vert_bin_cols += 1
170 | else:
171 | raise Exception('invalid property')
172 | vert_types[prop[2]] = len(vert_types)
173 | line = f.readline().decode()
174 |
175 | # parse content
176 | if binary:
177 | sz = n_verts * vert_bin_len
178 | fmt = ','.join(vert_bin_format)
179 | verts = np.ndarray(shape=(1, n_verts), dtype=np.dtype(fmt), buffer=f.read(sz))
180 | verts = verts[0].astype(vert_bin_cols*'f4,').view(dtype='f4').reshape((n_verts,-1))
181 | faces = []
182 | for idx in range(n_faces):
183 | fmt = '= 2 and len(parts[1]) > 0:
223 | tidx = int(parts[1]) - 1
224 | else:
225 | tidx = -1
226 | if len(parts) >= 3 and len(parts[2]) > 0:
227 | nidx = int(parts[2]) - 1
228 | else:
229 | nidx = -1
230 | return vidx, tidx, nidx
231 |
232 | def read_obj(path):
233 | with open(path, 'r') as fp:
234 | lines = fp.readlines()
235 |
236 | verts = []
237 | colors = []
238 | fnorms = []
239 | fnorm_map = collections.defaultdict(list)
240 | faces = []
241 | for line in lines:
242 | line = line.strip()
243 | if line.startswith('#') or len(line) == 0:
244 | continue
245 |
246 | parts = line.split()
247 | if line.startswith('v '):
248 | parts = parts[1:]
249 | x,y,z = float(parts[0]), float(parts[1]), float(parts[2])
250 | if len(parts) == 4 or len(parts) == 7:
251 | w = float(parts[3])
252 | x,y,z = x/w, y/w, z/w
253 | verts.append((x,y,z))
254 | if len(parts) >= 6:
255 | r,g,b = float(parts[-3]), float(parts[-2]), float(parts[-1])
256 | colors.append((r,g,b))
257 |
258 | elif line.startswith('vn '):
259 | parts = parts[1:]
260 | x,y,z = float(parts[0]), float(parts[1]), float(parts[2])
261 | fnorms.append((x,y,z))
262 |
263 | elif line.startswith('f '):
264 | parts = parts[1:]
265 | if len(parts) != 3:
266 | raise Exception('only triangle meshes supported atm')
267 | vidx0, tidx0, nidx0 = _read_obj_split_f(parts[0])
268 | vidx1, tidx1, nidx1 = _read_obj_split_f(parts[1])
269 | vidx2, tidx2, nidx2 = _read_obj_split_f(parts[2])
270 |
271 | faces.append((vidx0, vidx1, vidx2))
272 | if nidx0 >= 0:
273 | fnorm_map[vidx0].append( nidx0 )
274 | if nidx1 >= 0:
275 | fnorm_map[vidx1].append( nidx1 )
276 | if nidx2 >= 0:
277 | fnorm_map[vidx2].append( nidx2 )
278 |
279 | verts = np.array(verts)
280 | colors = np.array(colors)
281 | fnorms = np.array(fnorms)
282 | faces = np.array(faces)
283 |
284 | # face normals to vertex normals
285 | norms = np.zeros_like(verts)
286 | for vidx in fnorm_map.keys():
287 | ind = fnorm_map[vidx]
288 | norms[vidx] = fnorms[ind].sum(axis=0)
289 | N = np.linalg.norm(norms, axis=1, keepdims=True)
290 | np.divide(norms, N, out=norms, where=N != 0)
291 |
292 | return verts, faces, colors, norms
293 |
--------------------------------------------------------------------------------
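
A minimal usage sketch of `read_obj` (a hedged illustration, not repo code): it assumes the `co` package is importable from the repository root, and `model.obj` is a placeholder path. The normalization mirrors what `data/create_syn_data.py` applies to ShapeNet meshes.

```python
# Hedged sketch: 'model.obj' is a placeholder, not a file shipped with this repository.
import numpy as np
import co

verts, faces, colors, norms = co.io3d.read_obj('model.obj')

# Rescale and recenter the mesh, as data/create_syn_data.py does for ShapeNet objects.
diffs = verts.max(axis=0) - verts.min(axis=0)
verts /= (0.5 * diffs.max())
verts -= (verts.min(axis=0) + 1)
faces = faces.astype(np.int32)
print(verts.shape, faces.shape, norms.shape)
```
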
/co/metric.py:
--------------------------------------------------------------------------------
1 | # DepthInSpace is a PyTorch-based program which estimates 3D depth maps
2 | # from active structured-light sensor's multiple video frames.
3 | #
4 | # MIT License
5 | #
6 | # Copyright (c) 2019 autonomousvision
7 | #
8 | # Permission is hereby granted, free of charge, to any person obtaining a copy
9 | # of this software and associated documentation files (the "Software"), to deal
10 | # in the Software without restriction, including without limitation the rights
11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | # copies of the Software, and to permit persons to whom the Software is
13 | # furnished to do so, subject to the following conditions:
14 | #
15 | # The above copyright notice and this permission notice shall be included in all
16 | # copies or substantial portions of the Software.
17 | #
18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24 | # SOFTWARE.
25 |
26 | import numpy as np
27 | from . import geometry
28 |
29 | def _process_inputs(estimate, target, mask):
30 | if estimate.shape != target.shape:
31 | raise Exception('estimate and target have to be same shape')
32 | if mask is None:
33 | mask = np.ones(estimate.shape, dtype=bool)  # builtin bool; np.bool is deprecated in NumPy
34 | else:
35 | mask = mask != 0
36 | if estimate.shape != mask.shape:
37 | raise Exception('estimate and mask have to be same shape')
38 | return estimate, target, mask
39 |
40 | def mse(estimate, target, mask=None):
41 | estimate, target, mask = _process_inputs(estimate, target, mask)
42 | m = np.sum((estimate[mask] - target[mask])**2) / mask.sum()
43 | return m
44 |
45 | def rmse(estimate, target, mask=None):
46 | return np.sqrt(mse(estimate, target, mask))
47 |
48 | def mae(estimate, target, mask=None):
49 | estimate, target, mask = _process_inputs(estimate, target, mask)
50 | m = np.abs(estimate[mask] - target[mask]).sum() / mask.sum()
51 | return m
52 |
53 | def outlier_fraction(estimate, target, mask=None, threshold=0):
54 | estimate, target, mask = _process_inputs(estimate, target, mask)
55 | diff = np.abs(estimate[mask] - target[mask])
56 | m = (diff > threshold).sum() / mask.sum()
57 | return m
58 |
59 |
60 | class Metric(object):
61 | def __init__(self, str_prefix=''):
62 | self.str_prefix = str_prefix
63 | self.reset()
64 |
65 | def reset(self):
66 | pass
67 |
68 | def add(self, es, ta, ma=None):
69 | pass
70 |
71 | def get(self):
72 | return {}
73 |
74 | def items(self):
75 | return self.get().items()
76 |
77 | def __str__(self):
78 | return ', '.join([f'{self.str_prefix}{key}={value:.5f}' for key, value in self.get().items()])
79 |
80 | class MultipleMetric(Metric):
81 | def __init__(self, *metrics, **kwargs):
82 | self.metrics = [*metrics]
83 | super().__init__(**kwargs)
84 |
85 | def reset(self):
86 | for m in self.metrics:
87 | m.reset()
88 |
89 | def add(self, es, ta, ma=None):
90 | for m in self.metrics:
91 | m.add(es, ta, ma)
92 |
93 | def get(self):
94 | ret = {}
95 | for m in self.metrics:
96 | vals = m.get()
97 | for k in vals:
98 | ret[k] = vals[k]
99 | return ret
100 |
101 | def __str__(self):
102 | return '\n'.join([str(m) for m in self.metrics])
103 |
104 | class BaseDistanceMetric(Metric):
105 | def __init__(self, name='', **kwargs):
106 | super().__init__(**kwargs)
107 | self.name = name
108 |
109 | def reset(self):
110 | self.dists = []
111 |
112 | def add(self, es, ta, ma=None):
113 | pass
114 |
115 | def get(self):
116 | dists = np.hstack(self.dists)
117 | return {
118 | f'dist{self.name}_mean': float(np.mean(dists)),
119 | f'dist{self.name}_std': float(np.std(dists)),
120 | f'dist{self.name}_median': float(np.median(dists)),
121 | f'dist{self.name}_q10': float(np.percentile(dists, 10)),
122 | f'dist{self.name}_q90': float(np.percentile(dists, 90)),
123 | f'dist{self.name}_min': float(np.min(dists)),
124 | f'dist{self.name}_max': float(np.max(dists)),
125 | }
126 |
127 | class DistanceMetric(BaseDistanceMetric):
128 | def __init__(self, vec_length, p=2, **kwargs):
129 | super().__init__(name=f'{p}', **kwargs)
130 | self.vec_length = vec_length
131 | self.p = p
132 |
133 | def add(self, es, ta, ma=None):
134 | if es.shape != ta.shape or es.shape[1] != self.vec_length or es.ndim != 2:
135 | print(es.shape, ta.shape)
136 | raise Exception('es and ta have to be of shape Nxdim')
137 | if ma is not None:
138 | es = es[ma != 0]
139 | ta = ta[ma != 0]
140 | dist = np.linalg.norm(es - ta, ord=self.p, axis=1)
141 | self.dists.append( dist )
142 |
143 | class OutlierFractionMetric(DistanceMetric):
144 | def __init__(self, thresholds, *args, **kwargs):
145 | super().__init__(*args, **kwargs)
146 | self.thresholds = thresholds
147 |
148 | def get(self):
149 | dists = np.hstack(self.dists)
150 | ret = {}
151 | for t in self.thresholds:
152 | ma = dists > t
153 | ret[f'of{t}'] = float(ma.sum() / ma.size)
154 | return ret
155 |
156 | class RelativeDistanceMetric(BaseDistanceMetric):
157 | def __init__(self, vec_length, p=2, **kwargs):
158 | super().__init__(name=f'rel{p}', **kwargs)
159 | self.vec_length = vec_length
160 | self.p = p
161 |
162 | def add(self, es, ta, ma=None):
163 | if es.shape != ta.shape or es.shape[1] != self.vec_length or es.ndim != 2:
164 | raise Exception('es and ta have to be of shape Nxdim')
165 | dist = np.linalg.norm(es - ta, ord=self.p, axis=1)
166 | denom = np.linalg.norm(ta, ord=self.p, axis=1)
167 | dist /= denom
168 | if ma is not None:
169 | dist = dist[ma != 0]
170 | self.dists.append( dist )
171 |
172 | class RotmDistanceMetric(BaseDistanceMetric):
173 | def __init__(self, type='identity', **kwargs):
174 | super().__init__(name=type, **kwargs)
175 | self.type = type
176 |
177 | def add(self, es, ta, ma=None):
178 | if es.shape != ta.shape or es.shape[1] != 3 or es.shape[2] != 3 or es.ndim != 3:
179 | print(es.shape, ta.shape)
180 | raise Exception('es and ta have to be of shape Nx3x3')
181 | if ma is not None:
182 | raise Exception('mask is not implemented')
183 | if self.type == 'identity':
184 | self.dists.append( geometry.rotm_distance_identity(es, ta) )
185 | elif self.type == 'geodesic':
186 | self.dists.append( geometry.rotm_distance_geodesic_unit_sphere(es, ta) )
187 | else:
188 | raise Exception('invalid distance type')
189 |
190 | class QuaternionDistanceMetric(BaseDistanceMetric):
191 | def __init__(self, type='angle', **kwargs):
192 | super().__init__(name=type, **kwargs)
193 | self.type = type
194 |
195 | def add(self, es, ta, ma=None):
196 | if es.shape != ta.shape or es.shape[1] != 4 or es.ndim != 2:
197 | print(es.shape, ta.shape)
198 | raise Exception('es and ta have to be of shape Nx4')
199 | if ma is not None:
200 | raise Exception('mask is not implemented')
201 | if self.type == 'angle':
202 | self.dists.append( geometry.quat_distance_angle(es, ta) )
203 | elif self.type == 'mineucl':
204 | self.dists.append( geometry.quat_distance_mineucl(es, ta) )
205 | elif self.type == 'normdiff':
206 | self.dists.append( geometry.quat_distance_normdiff(es, ta) )
207 | else:
208 | raise Exception('invalid distance type')
209 |
210 |
211 | class BinaryAccuracyMetric(Metric):
212 | def __init__(self, thresholds=np.linspace(0.0, 1.0, num=101, dtype=np.float64)[:-1], **kwargs):
213 | self.thresholds = thresholds
214 | super().__init__(**kwargs)
215 |
216 | def reset(self):
217 | self.tps = [0 for wp in self.thresholds]
218 | self.fps = [0 for wp in self.thresholds]
219 | self.fns = [0 for wp in self.thresholds]
220 | self.tns = [0 for wp in self.thresholds]
221 | self.n_pos = 0
222 | self.n_neg = 0
223 |
224 | def add(self, es, ta, ma=None):
225 | if ma is not None:
226 | raise Exception('mask is not implemented')
227 | es = es.ravel()
228 | ta = ta.ravel()
229 | if es.shape[0] != ta.shape[0]:
230 | raise Exception('invalid shape of es, or ta')
231 | if es.min() < 0 or es.max() > 1:
232 | raise Exception('estimate has wrong value range')
233 | ta_p = (ta == 1)
234 | ta_n = (ta == 0)
235 | es_p = es[ta_p]
236 | es_n = es[ta_n]
237 | for idx, wp in enumerate(self.thresholds):
238 | wp = float(wp)  # np.asscalar has been removed from NumPy
239 | self.tps[idx] += (es_p > wp).sum()
240 | self.fps[idx] += (es_n > wp).sum()
241 | self.fns[idx] += (es_p <= wp).sum()
242 | self.tns[idx] += (es_n <= wp).sum()
243 | self.n_pos += ta_p.sum()
244 | self.n_neg += ta_n.sum()
245 |
246 | def get(self):
247 | tps = np.array(self.tps).astype(np.float32)
248 | fps = np.array(self.fps).astype(np.float32)
249 | fns = np.array(self.fns).astype(np.float32)
250 | tns = np.array(self.tns).astype(np.float32)
251 | wp = self.thresholds
252 |
253 | ret = {}
254 |
255 | precisions = np.divide(tps, tps + fps, out=np.zeros_like(tps), where=tps + fps != 0)
256 | recalls = np.divide(tps, tps + fns, out=np.zeros_like(tps), where=tps + fns != 0) # tprs
257 | fprs = np.divide(fps, fps + tns, out=np.zeros_like(tps), where=fps + tns != 0)
258 |
259 | precisions = np.r_[0, precisions, 1]
260 | recalls = np.r_[1, recalls, 0]
261 | fprs = np.r_[1, fprs, 0]
262 |
263 | ret['auc'] = float(-np.trapz(recalls, fprs))
264 | ret['prauc'] = float(-np.trapz(precisions, recalls))
265 | ret['ap'] = float(-(np.diff(recalls) * precisions[:-1]).sum())
266 |
267 | accuracies = np.divide(tps + tns, tps + tns + fps + fns)
268 | aacc = np.mean(accuracies)
269 | for t in np.linspace(0,1,num=11)[1:-1]:
270 | idx = np.argmin(np.abs(t - wp))
271 | ret[f'acc{wp[idx]:.2f}'] = float(accuracies[idx])
272 |
273 | return ret
274 |
--------------------------------------------------------------------------------
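
A minimal sketch of how the helpers in `co/metric.py` compose (hedged; the arrays are synthetic and not part of the repo's pipeline, and it assumes the `co` package is importable): the functional metrics work directly on arrays, while the `Metric` classes accumulate over repeated `add()` calls and can be bundled with `MultipleMetric`.

```python
# Hedged sketch with made-up data; inputs follow the Nx(vec_length) convention the classes expect.
import numpy as np
from co import metric

est = np.random.rand(1000, 1)
gt = est + 0.01 * np.random.randn(1000, 1)

# Functional metrics operate directly on arrays (optionally with a mask).
print('rmse:', metric.rmse(est, gt))
print('outliers > 0.02:', metric.outlier_fraction(est, gt, threshold=0.02))

# Stateful metrics accumulate over add() calls and can be combined.
m = metric.MultipleMetric(
    metric.DistanceMetric(vec_length=1),
    metric.OutlierFractionMetric(thresholds=[0.01, 0.02], vec_length=1),
)
m.add(est, gt)
print(m.get())   # dist2_mean, dist2_std, ..., of0.01, of0.02
```
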
/co/utils.py:
--------------------------------------------------------------------------------
1 | # DepthInSpace is a PyTorch-based program which estimates 3D depth maps
2 | # from active structured-light sensor's multiple video frames.
3 | #
4 | # MIT License
5 | #
6 | # Copyright (c) 2019 autonomousvision
7 | #
8 | # Permission is hereby granted, free of charge, to any person obtaining a copy
9 | # of this software and associated documentation files (the "Software"), to deal
10 | # in the Software without restriction, including without limitation the rights
11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | # copies of the Software, and to permit persons to whom the Software is
13 | # furnished to do so, subject to the following conditions:
14 | #
15 | # The above copyright notice and this permission notice shall be included in all
16 | # copies or substantial portions of the Software.
17 | #
18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24 | # SOFTWARE.
25 |
26 | import numpy as np
27 | import time
28 | from collections import OrderedDict
29 | import argparse
30 | import subprocess
31 |
32 | def str2bool(v):
33 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
34 | return True
35 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
36 | return False
37 | else:
38 | raise argparse.ArgumentTypeError('Boolean value expected.')
39 |
40 | class StopWatch(object):
41 | def __init__(self):
42 | self.timings = OrderedDict()
43 | self.starts = {}
44 |
45 | def start(self, name):
46 | self.starts[name] = time.time()
47 |
48 | def stop(self, name):
49 | if name not in self.timings:
50 | self.timings[name] = []
51 | self.timings[name].append(time.time() - self.starts[name])
52 |
53 | def get(self, name=None, reduce=np.sum):
54 | if name is not None:
55 | return reduce(self.timings[name])
56 | else:
57 | ret = {}
58 | for k in self.timings:
59 | ret[k] = reduce(self.timings[k])
60 | return ret
61 |
62 | def __repr__(self):
63 | return ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get().items()])
64 | def __str__(self):
65 | return ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get().items()])
66 |
67 | class ETA(object):
68 | def __init__(self, length):
69 | self.length = length
70 | self.start_time = time.time()
71 | self.current_idx = 0
72 | self.current_time = time.time()
73 |
74 | def update(self, idx):
75 | self.current_idx = idx
76 | self.current_time = time.time()
77 |
78 | def get_elapsed_time(self):
79 | return self.current_time - self.start_time
80 |
81 | def get_item_time(self):
82 | return self.get_elapsed_time() / (self.current_idx + 1)
83 |
84 | def get_remaining_time(self):
85 | return self.get_item_time() * (self.length - self.current_idx + 1)
86 |
87 | def format_time(self, seconds):
88 | minutes, seconds = divmod(seconds, 60)
89 | hours, minutes = divmod(minutes, 60)
90 | hours = int(hours)
91 | minutes = int(minutes)
92 | return f'{hours:02d}:{minutes:02d}:{seconds:05.2f}'
93 |
94 | def get_elapsed_time_str(self):
95 | return self.format_time(self.get_elapsed_time())
96 |
97 | def get_remaining_time_str(self):
98 | return self.format_time(self.get_remaining_time())
99 |
100 | def git_hash(cwd=None):
101 | ret = subprocess.run(['git', 'describe', '--always'], cwd=cwd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
102 | hash = ret.stdout
103 | if hash is not None and 'fatal' not in hash.decode():
104 | return hash.decode().strip()
105 | else:
106 | return None
107 |
108 |
--------------------------------------------------------------------------------
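
A minimal sketch of the timing helpers in `co/utils.py` (hedged; `time.sleep` stands in for real work and the loop length is arbitrary): `StopWatch` accumulates named timings, while `ETA` extrapolates the remaining time from the average per-item time.

```python
# Hedged sketch: the sleep call is a stand-in for actual computation.
import time
from co.utils import StopWatch, ETA

sw = StopWatch()
eta = ETA(length=10)
for i in range(10):
    sw.start('step')
    time.sleep(0.01)
    sw.stop('step')
    eta.update(i)
print('elapsed', eta.get_elapsed_time_str(), '| remaining', eta.get_remaining_time_str())
print(sw)   # per-name totals, e.g. "step: 0.1...[s]"
```
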
/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "OUTPUT_DIR": "/path/to/output/dir",
3 | "DATA_DIR": "/path/to/data/dir",
4 | "SHAPENET_DIR": "/path/to/ShapeNetCore.v2",
5 | "CTD_DIR": "/path/to/connecting_the_dots",
6 | "LITEFLOWNET_DIR": "/path/to/pytorch-liteflownet"
7 | }
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | # DepthInSpace is a PyTorch-based program which estimates 3D depth maps
2 | # from active structured-light sensor's multiple video frames.
3 | #
4 | # Copyright (c) 2021 ams International AG
5 | #
6 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
7 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation
8 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
9 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
10 | #
11 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
12 | #
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
14 | # OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
15 | # LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
16 | # IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
--------------------------------------------------------------------------------
/data/base_dataset.py:
--------------------------------------------------------------------------------
1 | # DepthInSpace is a PyTorch-based program which estimates 3D depth maps
2 | # from active structured-light sensor's multiple video frames.
3 | #
4 | # MIT License
5 | #
6 | # Copyright (c) 2019 autonomousvision
7 | #
8 | # Permission is hereby granted, free of charge, to any person obtaining a copy
9 | # of this software and associated documentation files (the "Software"), to deal
10 | # in the Software without restriction, including without limitation the rights
11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | # copies of the Software, and to permit persons to whom the Software is
13 | # furnished to do so, subject to the following conditions:
14 | #
15 | # The above copyright notice and this permission notice shall be included in all
16 | # copies or substantial portions of the Software.
17 | #
18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24 | # SOFTWARE.
25 |
26 | import torch.utils.data
27 | import numpy as np
28 |
29 | class TestSet(object):
30 | def __init__(self, name, dset, test_frequency=1):
31 | self.name = name
32 | self.dset = dset
33 | self.test_frequency = test_frequency
34 |
35 | class TestSets(list):
36 | def append(self, name, dset, test_frequency=1):
37 | super().append(TestSet(name, dset, test_frequency))
38 |
39 |
40 |
41 | class MultiDataset(torch.utils.data.Dataset):
42 | def __init__(self, *datasets):
43 | self.current_epoch = 0
44 |
45 | self.datasets = []
46 | self.cum_n_samples = [0]
47 |
48 | for dataset in datasets:
49 | self.append(dataset)
50 |
51 | def append(self, dataset):
52 | self.datasets.append(dataset)
53 | self.__update_cum_n_samples(dataset)
54 |
55 | def __update_cum_n_samples(self, dataset):
56 | n_samples = self.cum_n_samples[-1] + len(dataset)
57 | self.cum_n_samples.append(n_samples)
58 |
59 | def dataset_updated(self):
60 | self.cum_n_samples = [0]
61 | for dset in self.datasets:
62 | self.__update_cum_n_samples(dset)
63 |
64 | def __len__(self):
65 | return self.cum_n_samples[-1]
66 |
67 | def __getitem__(self, idx):
68 | didx = np.searchsorted(self.cum_n_samples, idx, side='right') - 1
69 | sidx = idx - self.cum_n_samples[didx]
70 | return self.datasets[didx][sidx]
71 |
72 |
73 |
74 | class BaseDataset(torch.utils.data.Dataset):
75 | def __init__(self, train=True, fix_seed_per_epoch=False):
76 | self.current_epoch = 0
77 | self.train = train
78 | self.fix_seed_per_epoch = fix_seed_per_epoch
79 |
80 | def get_rng(self, idx):
81 | rng = np.random.RandomState()
82 | if self.train:
83 | if self.fix_seed_per_epoch:
84 | seed = 1 * len(self) + idx
85 | else:
86 | seed = (self.current_epoch + 1) * len(self) + idx
87 | rng.seed(seed)
88 | else:
89 | rng.seed(idx)
90 | return rng
91 |
--------------------------------------------------------------------------------
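
A minimal sketch of how `BaseDataset` and `MultiDataset` behave (hedged; `ToyDataset` is a hypothetical subclass used only for illustration): `get_rng` yields a seed that is deterministic per epoch and index, and `MultiDataset` routes a global index to the right sub-dataset.

```python
# Hedged sketch: ToyDataset is made up and not part of the repository.
import numpy as np
from data.base_dataset import BaseDataset, MultiDataset

class ToyDataset(BaseDataset):
    def __init__(self, n, **kwargs):
        super().__init__(**kwargs)
        self.n = n
    def __len__(self):
        return self.n
    def __getitem__(self, idx):
        rng = self.get_rng(idx)   # seeded from (current_epoch, idx) in train mode
        return rng.uniform(size=(3,))

dset = MultiDataset(ToyDataset(10), ToyDataset(5))
print(len(dset))   # 15
print(dset[12])    # routed to sample 2 of the second dataset
```
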
/data/create_syn_data.py:
--------------------------------------------------------------------------------
1 | # DepthInSpace is a PyTorch-based program which estimates 3D depth maps
2 | # from active structured-light sensor's multiple video frames.
3 | #
4 | #
5 | # MIT License
6 | #
7 | # Copyright (c) 2019 autonomousvision
8 | #
9 | # Permission is hereby granted, free of charge, to any person obtaining a copy
10 | # of this software and associated documentation files (the "Software"), to deal
11 | # in the Software without restriction, including without limitation the rights
12 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13 | # copies of the Software, and to permit persons to whom the Software is
14 | # furnished to do so, subject to the following conditions:
15 | #
16 | # The above copyright notice and this permission notice shall be included in all
17 | # copies or substantial portions of the Software.
18 | #
19 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
20 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
22 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
23 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
24 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
25 | # SOFTWARE.
26 | #
27 | #
28 | # MIT License
29 | #
30 | # Copyright, 2021 ams International AG
31 | #
32 | # Permission is hereby granted, free of charge, to any person obtaining a copy
33 | # of this software and associated documentation files (the "Software"), to deal
34 | # in the Software without restriction, including without limitation the rights
35 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
36 | # copies of the Software, and to permit persons to whom the Software is
37 | # furnished to do so, subject to the following conditions:
38 | #
39 | # The above copyright notice and this permission notice shall be included in all
40 | # copies or substantial portions of the Software.
41 | #
42 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
43 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
44 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
45 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
46 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
47 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
48 | # SOFTWARE.
49 |
50 | import argparse
51 |
52 | import numpy as np
53 | import pickle
54 | from pathlib import Path
55 | import time
56 | import json
57 | import cv2
58 | import collections
59 | import sys
60 | import h5py
61 | import os
62 |
63 | sys.path.append('../')
64 | import co
65 | from data_manipulation import get_rotation_matrix, read_pattern_file, post_process
66 |
67 | config_path = os.path.abspath(os.path.join(os.path.dirname( __file__ ), '..', 'config.json'))
68 | with open(config_path) as fp:
69 | config = json.load(fp)
70 | CTD_path = Path(config['CTD_DIR'])
71 | sys.path.append(str(CTD_path / 'data'))
72 | sys.path.append(str(CTD_path / 'renderer'))
73 |
74 | from lcn import lcn
75 | from cyrender import PyRenderInput, PyCamera, PyShader, PyRenderer
76 |
77 | def get_objs(shapenet_dir, obj_classes, num_perclass=100):
78 | shapenet = {'chair': '03001627',
79 | 'airplane': '02691156',
80 | 'car': '02958343',
81 | 'watercraft': '04530566'}
82 |
83 | obj_paths = []
84 | for cls in obj_classes:
85 | if cls not in shapenet.keys():
86 | raise Exception('unknown class name')
87 | ids = shapenet[cls]
88 | obj_path = sorted(Path(f'{shapenet_dir}/{ids}').glob('**/models/*.obj'))
89 | obj_paths += obj_path[:num_perclass]
90 | print(f'found {len(obj_paths)} object paths')
91 |
92 | objs = []
93 | for obj_path in obj_paths:
94 | print(f'load {obj_path}')
95 | v, f, _, n = co.io3d.read_obj(obj_path)
96 | diffs = v.max(axis=0) - v.min(axis=0)
97 | v /= (0.5 * diffs.max())
98 | v -= (v.min(axis=0) + 1)
99 | f = f.astype(np.int32)
100 | objs.append((v, f, n))
101 | print(f'loaded {len(objs)} objects')
102 |
103 | return objs
104 |
105 |
106 | def get_mesh(rng, min_z=0):
107 | # set up background board
108 | verts, faces, normals, colors = [], [], [], []
109 | v, f, n = co.geometry.xyplane(z=0, interleaved=True)
110 | v[:, 2] += -v[:, 2].min() + rng.uniform(3, 5)
111 | v[:, :2] *= 5e2
112 | v[:, 2] = np.mean(v[:, 2]) + (v[:, 2] - np.mean(v[:, 2])) * 5e2
113 | c = np.empty_like(v)
114 | c[:] = rng.uniform(0, 1, size=(3,)).astype(np.float32)
115 | verts.append(v)
116 | faces.append(f)
117 | normals.append(n)
118 | colors.append(c)
119 |
120 | # randomly sample 4 foreground objects for each scene
121 | for shape_idx in range(4):
122 | v, f, n = objs[rng.randint(0, len(objs))]
123 | v, f, n = v.copy(), f.copy(), n.copy()
124 |
125 | s = rng.uniform(0.25, 1)
126 | v *= s
127 | R = co.geometry.rotm_from_quat(co.geometry.quat_random(rng=rng))
128 | v = v @ R.T
129 | n = n @ R.T
130 | v[:, 2] += -v[:, 2].min() + min_z + rng.uniform(0.5, 3)
131 | v[:, :2] += rng.uniform(-1, 1, size=(1, 2))
132 |
133 | c = np.empty_like(v)
134 | c[:] = rng.uniform(0, 1, size=(3,)).astype(np.float32)
135 |
136 | verts.append(v.astype(np.float32))
137 | faces.append(f)
138 | normals.append(n)
139 | colors.append(c)
140 |
141 | verts, faces = co.geometry.stack_mesh(verts, faces)
142 | normals = np.vstack(normals).astype(np.float32)
143 | colors = np.vstack(colors).astype(np.float32)
144 | return verts, faces, colors, normals
145 |
146 |
147 | def create_data(pattern_type, out_root, idx, n_samples, imsize_proj, imsize, pattern, K_proj, K, K_processed, baseline, blend_im, noise,
148 | track_length=4):
149 | tic = time.time()
150 | rng = np.random.RandomState()
151 |
152 | rng.seed(idx)
153 |
154 | verts, faces, colors, normals = get_mesh(rng)
155 | data = PyRenderInput(verts=verts.copy(), colors=colors.copy(), normals=normals.copy(), faces=faces.copy())
156 | print(f'loading mesh for sample {idx + 1}/{n_samples} took {time.time() - tic}[s]')
157 |
158 | # let the camera point to the center
159 | center = np.array([0, 0, 3], dtype=np.float32)
160 |
161 | basevec = np.array([-baseline, 0, 0], dtype=np.float32)
162 | unit = np.array([0, 0, 1], dtype=np.float32)
163 |
164 | cam_x_ = rng.uniform(-0.2, 0.2)
165 | cam_y_ = rng.uniform(-0.2, 0.2)
166 | cam_z_ = rng.uniform(-0.2, 0.2)
167 |
168 | ret = collections.defaultdict(list)
169 | blend_im_rnd = np.clip(blend_im + rng.uniform(-0.1, 0.1), 0, 1)
170 |
171 | # capture the same static scene from different view points as a track
172 | for ind in range(track_length):
173 |
174 | cam_x = cam_x_ + rng.uniform(-0.1, 0.1)
175 | cam_y = cam_y_ + rng.uniform(-0.1, 0.1)
176 | cam_z = cam_z_ + rng.uniform(-0.1, 0.1)
177 |
178 | tcam = np.array([cam_x, cam_y, cam_z], dtype=np.float32)
179 |
180 | if np.linalg.norm(tcam[0:2]) < 1e-9:
181 | Rcam = np.eye(3, dtype=np.float32)
182 | else:
183 | Rcam = get_rotation_matrix(center, center - tcam)
184 |
185 | tproj = tcam + basevec
186 | Rproj = Rcam
187 |
188 | ret['R'].append(Rcam)
189 | ret['t'].append(tcam)
190 |
191 | fx_proj = K_proj[0, 0]
192 | fy_proj = K_proj[1, 1]
193 | px_proj = K_proj[0, 2]
194 | py_proj = K_proj[1, 2]
195 | im_height_proj = imsize_proj[0]
196 | im_width_proj = imsize_proj[1]
197 | proj = PyCamera(fx_proj, fy_proj, px_proj, py_proj, Rproj, tproj, im_width_proj, im_height_proj)
198 |
199 | fx = K[0, 0]
200 | fy = K[1, 1]
201 | px = K[0, 2]
202 | py = K[1, 2]
203 | im_height = imsize[0]
204 | im_width = imsize[1]
205 | cam = PyCamera(fx, fy, px, py, Rcam, tcam, im_width, im_height)
206 |
207 | shader = PyShader(0.5, 1.5, 0.0, 10)
208 | pyrenderer = PyRenderer(cam, shader, engine='gpu')
209 | if pattern_type == 'default':
210 | pyrenderer.mesh_proj(data, proj, pattern.copy(), d_alpha=0, d_beta=0.0)
211 | else:
212 | pyrenderer.mesh_proj(data, proj, pattern.copy(), d_alpha=0, d_beta=0.35)
213 |
214 | # get the reflected laser pattern $R$
215 | im = pyrenderer.color().copy()
216 | im = cv2.cvtColor(im, cv2.COLOR_RGB2GRAY)
217 |
218 | focal_length = K_processed[0, 0]
219 | depth = pyrenderer.depth().copy()
220 | disp = baseline * focal_length / depth
221 |
222 | # get the ambient image $A$
223 | ambient = pyrenderer.normal().copy()
224 | ambient = cv2.cvtColor(ambient, cv2.COLOR_RGB2GRAY)
225 |
226 | # get the noise free IR image $J$
227 | im = blend_im_rnd * im + (1 - blend_im_rnd) * ambient
228 | ret['ambient'].append(post_process(pattern_type, ambient)[None].astype(np.float32))
229 |
230 | # get the gradient magnitude of the ambient image $|\nabla A|$
231 | ambient = ambient.astype(np.float32)
232 | sobelx = cv2.Sobel(ambient, cv2.CV_32F, 1, 0, ksize=5)
233 | sobely = cv2.Sobel(ambient, cv2.CV_32F, 0, 1, ksize=5)
234 | grad = np.sqrt(sobelx ** 2 + sobely ** 2)
235 | grad = np.maximum(grad - 0.8, 0.0) # parameter
236 |
237 | # get the local contrast normalized gradient LCN($|\nabla A|$)
238 | grad_lcn, grad_std = lcn.normalize(grad, 5, 0.1)
239 | grad_lcn = np.clip(grad_lcn, 0.0, 1.0) # parameter
240 | ret['grad'].append(post_process(pattern_type, grad_lcn)[None].astype(np.float32))
241 |
242 | ret['im'].append(post_process(pattern_type, im)[None].astype(np.float32))
243 | ret['disp'].append(post_process(pattern_type, disp)[None].astype(np.float32))
244 |
245 | for key in ret.keys():
246 | ret[key] = np.stack(ret[key], axis=0)
247 |
248 | # save to files
249 | out_dir = out_root / f'{idx:08d}'
250 | out_dir.mkdir(exist_ok=True, parents=True)
251 | out_path = out_dir / f'frames.hdf5'
252 | with h5py.File(out_path, "w") as f:
253 | for k, val in ret.items():
254 | # f.create_dataset(k, data=val, compression="lzf")
255 | f.create_dataset(k, data=val)
256 |
257 | print(f'create sample {idx + 1}/{n_samples} took {time.time() - tic}[s]')
258 |
259 |
260 | if __name__ == '__main__':
261 |
262 | parser = argparse.ArgumentParser()
263 | parser.add_argument('pattern_type',
264 | help='Select the pattern file for projecting dots',
265 | default='default',
266 | choices=['default', 'kinect', 'real'], type=str)
267 | args = parser.parse_args()
268 |
269 | np.random.seed(42)
270 |
271 | # output directory
272 | config_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'config.json'))
273 | with open(config_path) as fp:
274 | config = json.load(fp)
275 | data_root = Path(config['DATA_DIR'])
276 | shapenet_root = config['SHAPENET_DIR']
277 |
278 | out_root = data_root
279 | out_root.mkdir(parents=True, exist_ok=True)
280 |
281 | # load shapenet models
282 | obj_classes = ['chair']
283 | objs = get_objs(shapenet_root, obj_classes)
284 |
285 | # camera parameters
286 | if args.pattern_type == 'real':
287 | fl_proj = 1112.1806640625
288 | fl = 1112.1806640625
289 | imsize_proj = (1280, 1080)
290 | imsize = (1280, 1080)
291 | imsize_processed = (512, 432)
292 | K_proj = np.array([[fl_proj, 0, 517.0896606445312], [0, fl_proj, 649.6329956054688], [0, 0, 1]], dtype=np.float32)
293 | K = np.array([[fl, 0, 517.0896606445312], [0, fl, 649.6329956054688], [0, 0, 1]], dtype=np.float32)
294 | baseline = 0.0246
295 | blend_im = 0.6
296 | noise = 0
297 | else:
298 | fl_proj = 1582.06005876
299 | fl = 435.2
300 | imsize_proj = (4096, 4096)
301 | imsize = (512, 432)
302 | imsize_processed = (512, 432)
303 | K_proj = np.array([[fl_proj, 0, 2047.5], [0, fl_proj, 2047.5], [0, 0, 1]], dtype=np.float32)
304 | K = np.array([[fl, 0, 216], [0, fl, 256], [0, 0, 1]], dtype=np.float32)
305 | baseline = 0.025
306 | blend_im = 0.6
307 | noise = 0
308 |
309 | # capture the same static scene from different view points as a track
310 | track_length = 4
311 |
312 | # load pattern image
313 | pattern = read_pattern_file(args.pattern_type, imsize_proj)
314 |
315 | x_cam = np.arange(0, imsize[1])
316 | y_cam = np.arange(0, imsize[0])
317 | x_mesh, y_mesh = np.meshgrid(x_cam, y_cam)
318 | x_mesh_f = np.reshape(x_mesh, [-1])
319 | y_mesh_f = np.reshape(y_mesh, [-1])
320 |
321 | grid_points = np.stack([x_mesh_f, y_mesh_f, np.ones_like(x_mesh_f)], axis=0)
322 | grid_points_mapped = K_proj.dot(np.linalg.inv(K).dot(grid_points))
323 | grid_points_mapped = grid_points_mapped / grid_points_mapped[2, :]
324 |
325 | x_map = np.reshape(grid_points_mapped[0, :], x_mesh.shape)
326 | y_map = np.reshape(grid_points_mapped[1, :], y_mesh.shape)
327 | x_map, y_map = x_map.astype('float32'), y_map.astype('float32')
328 | mapped_pattern = cv2.remap(pattern, x_map, y_map, cv2.INTER_LINEAR)
329 |
330 | pattern_processed, K_processed = post_process(args.pattern_type, mapped_pattern, K)
331 | # write settings to file
332 | settings = {
333 | 'imsize': imsize_processed,
334 | 'pattern': pattern_processed,
335 | 'baseline': baseline,
336 | 'K': K_processed,
337 | }
338 | out_path = out_root / f'settings.pkl'
339 | print(f'write settings to {out_path}')
340 | with open(str(out_path), 'wb') as f:
341 | pickle.dump(settings, f, pickle.HIGHEST_PROTOCOL)
342 |
343 | # start the job
344 | n_samples = 2 ** 10 + 2 ** 13
345 | # n_samples = 2048
346 | for idx in range(n_samples):
347 | parameters = (
348 | args.pattern_type, out_root, idx, n_samples, imsize_proj, imsize, pattern, K_proj, K, K_processed, baseline, blend_im, noise, track_length)
349 | create_data(*parameters)
--------------------------------------------------------------------------------
/data/data_manipulation.py:
--------------------------------------------------------------------------------
1 | # DepthInSpace is a PyTorch-based program which estimates 3D depth maps
2 | # from active structured-light sensor's multiple video frames.
3 | #
4 | #
5 | # MIT License
6 | #
7 | # Copyright (c) 2019 autonomousvision
8 | #
9 | # Permission is hereby granted, free of charge, to any person obtaining a copy
10 | # of this software and associated documentation files (the "Software"), to deal
11 | # in the Software without restriction, including without limitation the rights
12 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13 | # copies of the Software, and to permit persons to whom the Software is
14 | # furnished to do so, subject to the following conditions:
15 | #
16 | # The above copyright notice and this permission notice shall be included in all
17 | # copies or substantial portions of the Software.
18 | #
19 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
20 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
22 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
23 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
24 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
25 | # SOFTWARE.
26 | #
27 | #
28 | # MIT License
29 | #
30 | # Copyright, 2021 ams International AG
31 | #
32 | # Permission is hereby granted, free of charge, to any person obtaining a copy
33 | # of this software and associated documentation files (the "Software"), to deal
34 | # in the Software without restriction, including without limitation the rights
35 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
36 | # copies of the Software, and to permit persons to whom the Software is
37 | # furnished to do so, subject to the following conditions:
38 | #
39 | # The above copyright notice and this permission notice shall be included in all
40 | # copies or substantial portions of the Software.
41 | #
42 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
43 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
44 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
45 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
46 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
47 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
48 | # SOFTWARE.
49 |
50 | import numpy as np
51 | import cv2
52 |
53 | def read_pattern_file(pattern_type, pattern_size):
54 | if pattern_type == 'default':
55 | pattern_path = 'default_pattern.png'
56 | elif pattern_type == 'kinect':
57 | pattern_path = 'kinect_pattern.png'
58 | elif pattern_type == 'real':
59 | pattern_path = 'real_pattern.png'
60 |
61 | pattern = cv2.imread(pattern_path)
62 | pattern = pattern.astype(np.float32)
63 | pattern /= 255
64 |
65 | if pattern.ndim == 2:
66 | pattern = np.stack([pattern for idx in range(3)], axis=2)
67 |
68 | if pattern_type == 'default':
69 | pattern = np.rot90(np.flip(pattern, axis=1))
70 | elif pattern_type == 'kinect':
71 | min_dim = min(pattern.shape[0:2])
72 | start_h = (pattern.shape[0] - min_dim) // 2
73 | start_w = (pattern.shape[1] - min_dim) // 2
74 | pattern = pattern[start_h:start_h + min_dim, start_w:start_w + min_dim]
75 | pattern = cv2.resize(pattern, pattern_size, interpolation=cv2.INTER_LINEAR)
76 |
77 | return pattern
78 |
79 | def get_rotation_matrix(v0, v1):
80 | v0 = v0/np.linalg.norm(v0)
81 | v1 = v1/np.linalg.norm(v1)
82 | v = np.cross(v0,v1)
83 | c = np.dot(v0,v1)
84 | s = np.linalg.norm(v)
85 | I = np.eye(3)
86 | vXStr = '{} {} {}; {} {} {}; {} {} {}'.format(0, -v[2], v[1], v[2], 0, -v[0], -v[1], v[0], 0)
87 | k = np.matrix(vXStr)
88 | r = I + k + k @ k * ((1 -c)/(s**2))
89 | return np.asarray(r.astype(np.float32))
90 |
91 | def post_process(pattern_type, im, K = None):
92 | if pattern_type == 'real':
93 | im_processed = im[128:-128, 108:-108, ...].copy()
94 | im_processed = cv2.resize(im_processed, (432, 512), interpolation=cv2.INTER_LINEAR)
95 |
96 | if K is not None:
97 | K_processed = K.copy()
98 |
99 | K_processed[0, 0] = K_processed[0, 0] / 2
100 | K_processed[1, 1] = K_processed[1, 1] / 2
101 |
102 | K_processed[0, 2] = (K_processed[0, 2] - 108) / 2
103 | K_processed[1, 2] = (K_processed[1, 2] - 128) / 2
104 |
105 | return im_processed, K_processed
106 | else:
107 | return im_processed
108 | else:
109 | if K is not None:
110 | return im, K
111 | else:
112 | return im
113 |
114 | def augment_image(img,rng, amb=None, disp=None, primary_disp= None, sgm_disp= None, grad=None,max_shift=64,max_blur=1.5,max_noise=10.0,max_sp_noise=0.001):
115 |
116 | # get min/max values of image
117 | min_val = np.min(img)
118 | max_val = np.max(img)
119 |
120 | # init augmented image
121 | img_aug = img
122 | amb_aug = amb
123 |
124 | # init disparity correction map
125 | disp_aug = disp
126 | primary_disp_aug = primary_disp
127 | sgm_disp_aug = sgm_disp
128 | grad_aug = grad
129 |
130 | # apply affine transformation
131 | if max_shift>1:
132 |
133 | # affine parameters
134 | rows,cols = img.shape
135 | shear = 0
136 | shift = 0
137 | shear_correction = 0
138 | if rng.uniform(0,1)<0.75: shear = rng.uniform(-max_shift,max_shift) # shear with 75% probability
139 | else: shift = rng.uniform(-max_shift/2,max_shift) # shift with 25% probability
140 | if shear<0: shear_correction = -shear
141 |
142 | # affine transformation
143 | a = shear/float(rows)
144 | b = shift+shear_correction
145 |
146 | # warp image
147 | T = np.float32([[1,a,b],[0,1,0]])
148 | img_aug = cv2.warpAffine(img_aug,T,(cols,rows))
149 | if amb is not None:
150 | amb_aug = cv2.warpAffine(amb_aug,T,(cols,rows))
151 | if grad is not None:
152 | grad_aug = cv2.warpAffine(grad,T,(cols,rows))
153 |
154 | # disparity correction map
155 | col = a*np.array(range(rows))+b
156 | disp_delta = np.tile(col,(cols,1)).transpose()
157 | if disp is not None:
158 | disp_aug = cv2.warpAffine(disp+disp_delta,T,(cols,rows))
159 | if primary_disp is not None:
160 | primary_disp_aug = cv2.warpAffine(primary_disp+disp_delta,T,(cols,rows))
161 | if sgm_disp is not None:
162 | sgm_disp_aug = cv2.warpAffine(sgm_disp+disp_delta,T,(cols,rows))
163 |
164 | # gaussian smoothing
165 | if rng.uniform(0,1)<0.5:
166 | img_aug = cv2.GaussianBlur(img_aug,(5,5),rng.uniform(0.2,max_blur))
167 | if amb is not None:
168 | amb_aug = cv2.GaussianBlur(amb_aug,(5,5),rng.uniform(0.2,max_blur))
169 |
170 | # per-pixel gaussian noise
171 | img_aug = img_aug + rng.randn(*img_aug.shape)*rng.uniform(0.0,max_noise)/255.0
172 | if amb is not None:
173 | amb_aug = amb_aug + rng.randn(*amb_aug.shape)*rng.uniform(0.0,max_noise)/255.0
174 |
175 | # salt-and-pepper noise
176 | if rng.uniform(0,1)<0.5:
177 | ratio=rng.uniform(0.0,max_sp_noise)
178 | img_shape = img_aug.shape
179 | img_aug = img_aug.flatten()
180 | coord = rng.choice(np.size(img_aug), int(np.size(img_aug)*ratio))
181 | img_aug[coord] = max_val
182 | coord = rng.choice(np.size(img_aug), int(np.size(img_aug)*ratio))
183 | img_aug[coord] = min_val
184 | img_aug = np.reshape(img_aug, img_shape)
185 |
186 | # clip intensities back to [0,1]
187 | img_aug = np.maximum(img_aug,0.0)
188 | img_aug = np.minimum(img_aug,1.0)
189 |
190 | if amb is not None:
191 | amb_aug = np.maximum(amb_aug,0.0)
192 | amb_aug = np.minimum(amb_aug,1.0)
193 |
194 | # return image
195 | return img_aug, amb_aug, disp_aug, primary_disp_aug, sgm_disp_aug, grad_aug
196 |
--------------------------------------------------------------------------------
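
A minimal check of `get_rotation_matrix` (hedged; the vectors are arbitrary examples and the snippet is assumed to run from the `data/` directory so the import matches how `create_syn_data.py` uses the module): the function implements Rodrigues' formula and returns the rotation that maps the direction of `v0` onto the direction of `v1`; it divides by zero for parallel inputs, which the caller in `create_syn_data.py` guards against.

```python
# Hedged sketch: v0 and v1 are made-up example vectors.
import numpy as np
from data_manipulation import get_rotation_matrix

v0 = np.array([0.0, 0.0, 1.0])
v1 = np.array([0.3, -0.2, 1.0])

R = get_rotation_matrix(v0, v1)
print(np.allclose(R @ v0, v1 / np.linalg.norm(v1), atol=1e-5))   # True: v0 rotated onto v1's direction
print(np.allclose(R @ R.T, np.eye(3), atol=1e-5))                # True: R is orthonormal
```
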
/data/dataset.py:
--------------------------------------------------------------------------------
1 | # DepthInSpace is a PyTorch-based program which estimates 3D depth maps
2 | # from active structured-light sensor's multiple video frames.
3 | #
4 | # MIT License
5 | #
6 | # Copyright (c) 2019 autonomousvision
7 | #
8 | # Permission is hereby granted, free of charge, to any person obtaining a copy
9 | # of this software and associated documentation files (the "Software"), to deal
10 | # in the Software without restriction, including without limitation the rights
11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | # copies of the Software, and to permit persons to whom the Software is
13 | # furnished to do so, subject to the following conditions:
14 | #
15 | # The above copyright notice and this permission notice shall be included in all
16 | # copies or substantial portions of the Software.
17 | #
18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24 | # SOFTWARE.
25 |
26 | import h5py
27 | import numpy as np
28 | import pickle
29 | import cv2
30 | import os
31 |
32 | from . import base_dataset
33 | from .data_manipulation import augment_image
34 |
35 |
36 | class TrackSynDataset(base_dataset.BaseDataset):
37 | '''
38 | Load locally saved synthetic dataset
39 | '''
40 | def __init__(self, settings_path, sample_paths, track_length=2, train=True, data_aug=False, load_flow_data = False, load_primary_data = False, load_pseudo_gt = False, data_type = 'synthetic'):
41 | super().__init__(train=train)
42 |
43 | self.settings_path = settings_path
44 | self.sample_paths = sample_paths
45 | self.data_aug = data_aug
46 | self.train = train
47 | self.track_length=track_length
48 | self.load_flow_data = load_flow_data
49 | self.load_primary_data = load_primary_data
50 | self.load_pseudo_gt = load_pseudo_gt
51 | self.data_type = data_type
52 | assert(track_length<=4)
53 |
54 | with open(str(settings_path), 'rb') as f:
55 | settings = pickle.load(f)
56 | self.imsizes = [(settings['imsize'][0] // (2 ** s), settings['imsize'][1] // (2 ** s)) for s in range(4)]
57 | self.patterns = []
58 | for imsize in self.imsizes:
59 | pat = cv2.resize(settings['pattern'], (imsize[1], imsize[0]), interpolation=cv2.INTER_LINEAR)
60 | self.patterns.append(pat)
61 | self.baseline = settings['baseline']
62 | self.K = settings['K']
63 | self.focal_lengths = [self.K[0,0]/(2**s) for s in range(4)]
64 |
65 | self.scale = len(self.imsizes)
66 |
67 | self.max_shift = 0
68 | self.max_blur = 0.5
69 | self.max_noise = 3.0
70 | self.max_sp_noise = 0.0005
71 |
72 | def __len__(self):
73 | return len(self.sample_paths)
74 |
75 | def __getitem__(self, idx):
76 | if not self.train:
77 | rng = self.get_rng(idx)
78 | else:
79 | rng = np.random.RandomState()
80 | sample_path = self.sample_paths[idx]
81 |
82 | if self.train:
83 | track_ind = np.random.permutation(4)[0:self.track_length]
84 | else:
85 | track_ind = np.arange(0, self.track_length)
86 |
87 | ret = {}
88 | ret['id'] = idx
89 |
90 | with h5py.File(os.path.join(sample_path,f'frames.hdf5'), "r") as sample_frames:
91 | # load imgs, at all scales
92 | for sidx in range(len(self.imsizes)):
93 | im = sample_frames['im']
94 | amb = sample_frames['ambient']
95 | grad = sample_frames['grad']
96 | if sidx == 0:
97 | ret[f'im{sidx}'] = np.stack([im[tidx, ...] for tidx in track_ind], axis=0)
98 | ret[f'ambient{sidx}'] = np.stack([amb[tidx, ...] for tidx in track_ind], axis=0)
99 | ret[f'grad{sidx}'] = np.stack([grad[tidx, ...] for tidx in track_ind], axis=0)
100 | else:
101 | ret[f'im{sidx}'] = np.stack([cv2.resize(im[tidx, 0, ...], self.imsizes[sidx][::-1])[None] for tidx in track_ind], axis=0)
102 | ret[f'ambient{sidx}'] = np.stack([cv2.resize(amb[tidx, 0, ...], self.imsizes[sidx][::-1])[None] for tidx in track_ind], axis=0)
103 | ret[f'grad{sidx}'] = np.stack([cv2.resize(grad[tidx, 0, ...], self.imsizes[sidx][::-1])[None] for tidx in track_ind], axis=0)
104 |
105 | # load disp and grad only at full resolution
106 | ret[f'disp0'] = np.stack([sample_frames['disp'][tidx, ...] for tidx in track_ind], axis=0)
107 | ret['R'] = np.stack([sample_frames['R'][tidx, ...] for tidx in track_ind], axis=0)
108 | ret['t'] = np.stack([sample_frames['t'][tidx, ...] for tidx in track_ind], axis=0)
109 | if self.data_type == 'real':
110 | ret[f'sgm_disp'] = np.stack([sample_frames['sgm_disp'][tidx, ...] for tidx in track_ind], axis=0)
111 |
112 | if self.load_flow_data:
113 | with h5py.File(os.path.join(sample_path, f'flow.hdf5'), "r") as sample_flow:
114 | for i0, tidx0 in enumerate(track_ind):
115 | for i1, tidx1 in enumerate(track_ind):
116 | if tidx0 != tidx1:
117 | ret[f'flow_{i0}{i1}'] = sample_flow[f'flow_{tidx0}{tidx1}'][:]
118 |
119 | if self.load_primary_data:
120 | with h5py.File(os.path.join(sample_path, f'single_frame_disp.hdf5'), "r") as sample_primary_disp:
121 | ret[f'primary_disp'] = np.stack([sample_primary_disp['disp'][tidx, ...] for tidx in track_ind], axis=0)
122 |
123 | if self.load_pseudo_gt:
124 | with h5py.File(os.path.join(sample_path, f'multi_frame_disp.hdf5'), "r") as sample_disp:
125 | ret[f'pseudo_gt'] = np.stack([sample_disp['disp'][tidx, ...] for tidx in track_ind], axis=0)
126 |
127 | #### apply data augmentation at different scales separately; only works for max_shift=0
128 | if self.data_aug:
129 | for sidx in range(len(self.imsizes)):
130 | if sidx==0:
131 | img = ret[f'im{sidx}']
132 | amb = ret[f'ambient{sidx}']
133 | disp = ret[f'disp{sidx}']
134 | if self.load_primary_data:
135 | primary_disp = ret[f'primary_disp']
136 | else:
137 | primary_disp = None
138 | if self.data_type == 'real':
139 | sgm_disp = ret[f'sgm_disp']
140 | else:
141 | sgm_disp = None
142 | grad = ret[f'grad{sidx}']
143 | img_aug = np.zeros_like(img)
144 | amb_aug = np.zeros_like(img)
145 | disp_aug = np.zeros_like(img)
146 | primary_disp_aug = np.zeros_like(img)
147 | sgm_disp_aug = np.zeros_like(img)
148 | grad_aug = np.zeros_like(img)
149 | for i in range(img.shape[0]):
150 | if self.load_primary_data:
151 | primary_disp_i = primary_disp[i,0]
152 | else:
153 | primary_disp_i = None
154 | if self.data_type == 'real':
155 | sgm_disp_i = sgm_disp[i,0]
156 | else:
157 | sgm_disp_i = None
158 | img_aug_, amb_aug_, disp_aug_, primary_disp_aug_, sgm_disp_aug_, grad_aug_ = augment_image(img[i,0],rng,
159 | amb=amb[i,0],disp=disp[i,0],primary_disp=primary_disp_i, sgm_disp= sgm_disp_i,grad=grad[i,0],
160 | max_shift=self.max_shift, max_blur=self.max_blur,
161 | max_noise=self.max_noise, max_sp_noise=self.max_sp_noise)
162 | img_aug[i] = img_aug_[None].astype(np.float32)
163 | amb_aug[i] = amb_aug_[None].astype(np.float32)
164 | disp_aug[i] = disp_aug_[None].astype(np.float32)
165 | if self.load_primary_data:
166 | primary_disp_aug[i] = primary_disp_aug_[None].astype(np.float32)
167 | if self.data_type == 'real':
168 | sgm_disp_aug[i] = sgm_disp_aug_[None].astype(np.float32)
169 | grad_aug[i] = grad_aug_[None].astype(np.float32)
170 | ret[f'im{sidx}'] = img_aug
171 | ret[f'ambient{sidx}'] = amb_aug
172 | ret[f'disp{sidx}'] = disp_aug
173 | if self.load_primary_data:
174 | ret[f'primary_disp'] = primary_disp_aug
175 | if self.data_type == 'real':
176 | ret[f'sgm_disp'] = sgm_disp_aug
177 | ret[f'grad{sidx}'] = grad_aug
178 | else:
179 | img = ret[f'im{sidx}']
180 | img_aug = np.zeros_like(img)
181 | for i in range(img.shape[0]):
182 | img_aug_, _, _, _, _, _ = augment_image(img[i,0],rng,
183 | max_shift=self.max_shift, max_blur=self.max_blur,
184 | max_noise=self.max_noise, max_sp_noise=self.max_sp_noise)
185 | img_aug[i] = img_aug_[None].astype(np.float32)
186 | ret[f'im{sidx}'] = img_aug
187 |
188 | return ret
189 |
190 | def getK(self, sidx=0):
191 | K = self.K.copy() / (2**sidx)
192 | K[2,2] = 1
193 | return K
194 |
195 |
196 |
197 | if __name__ == '__main__':
198 | pass
199 |
200 |
--------------------------------------------------------------------------------
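
A minimal loading sketch for `TrackSynDataset` (hedged; it assumes `data/create_syn_data.py` has already populated `DATA_DIR` with `settings.pkl` and per-sample `frames.hdf5` files, that the snippet runs from the repository root, and that the batch size is illustrative only).

```python
# Hedged sketch: paths come from config.json; the DataLoader settings are arbitrary.
import json
from pathlib import Path
import torch
from data.dataset import TrackSynDataset

with open('config.json') as fp:
    config = json.load(fp)
data_root = Path(config['DATA_DIR'])

sample_paths = sorted(p for p in data_root.iterdir() if p.is_dir())
dset = TrackSynDataset(data_root / 'settings.pkl', sample_paths,
                       track_length=2, train=True, data_aug=True)
loader = torch.utils.data.DataLoader(dset, batch_size=4, shuffle=True)

batch = next(iter(loader))
print(batch['im0'].shape)   # (batch, track_length, 1, H, W)
print(dset.getK(sidx=1))    # camera intrinsics at half resolution
```
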
/data/default_pattern.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/idiap/DepthInSpace/fe759807f82df4c48c16b97f061718175ea0e6e9/data/default_pattern.png
--------------------------------------------------------------------------------
/data/kinect_pattern.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/idiap/DepthInSpace/fe759807f82df4c48c16b97f061718175ea0e6e9/data/kinect_pattern.png
--------------------------------------------------------------------------------
/data/presave_disp.py:
--------------------------------------------------------------------------------
1 | # DepthInSpace is a PyTorch-based program which estimates 3D depth maps
2 | # from active structured-light sensor's multiple video frames.
3 | #
4 | # MIT License
5 | #
6 | # Copyright, 2021 ams International AG
7 | #
8 | # Permission is hereby granted, free of charge, to any person obtaining a copy
9 | # of this software and associated documentation files (the "Software"), to deal
10 | # in the Software without restriction, including without limitation the rights
11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | # copies of the Software, and to permit persons to whom the Software is
13 | # furnished to do so, subject to the following conditions:
14 | #
15 | # The above copyright notice and this permission notice shall be included in all
16 | # copies or substantial portions of the Software.
17 | #
18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24 | # SOFTWARE.
25 |
26 | import numpy as np
27 | from pathlib import Path
28 | import os
29 | import sys
30 | import argparse
31 | import json
32 | import pickle
33 | import h5py
34 | import torch
35 | from tqdm import tqdm
36 |
37 | sys.path.append('../')
38 | from model import networks
39 | from model import multi_frame_networks
40 |
41 | if __name__ == '__main__':
42 |
43 | parser = argparse.ArgumentParser()
44 | parser.add_argument('architecture',
45 | help='Select the architecture to produce disparity (single_frame/multi_frame)',
46 | choices=['single_frame', 'multi_frame'], type=str)
47 | parser.add_argument('--epoch',
48 | help='Epoch whose results will be pre-saved',
49 | default=-1, type=int)
50 | args = parser.parse_args()
51 |
52 | # output directory
53 | config_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'config.json'))
54 | with open(config_path) as fp:
55 | config = json.load(fp)
56 | data_root = Path(config['DATA_DIR'])
57 | output_dir = Path(config['OUTPUT_DIR'])
58 |
59 | model_path = output_dir / args.architecture / f'net_{args.epoch:04d}.params'
60 | settings_path = data_root / 'settings.pkl'
61 | with open(str(settings_path), 'rb') as f:
62 | settings = pickle.load(f)
63 | imsizes = [(settings['imsize'][0] // (2 ** s), settings['imsize'][1] // (2 ** s)) for s in range(4)]
64 | K = settings['K']
65 | Ki = np.linalg.inv(K)
66 | baseline = settings['baseline']
67 | pat = settings['pattern']
68 |
69 | d2d = networks.DispToDepth(focal_length= K[0, 0],baseline= baseline).cuda()
70 | lcn_in = networks.LCN(5, 0.05).cuda()
71 |
72 | pat = pat.mean(axis=2)
73 | pat = torch.from_numpy(pat[None][None].astype(np.float32)).to('cuda')
74 | pat_lcn, _ = lcn_in(pat)
75 | pat_cat = torch.cat((pat_lcn, pat), dim=1)
76 |
77 | if args.architecture == 'single_frame':
78 | net = networks.DispDecoder(channels_in=2, max_disp=128, imsizes=imsizes).cuda().eval()
79 | elif args.architecture == 'multi_frame':
80 | net = multi_frame_networks.FuseNet(imsize=imsizes[0], K=K, baseline=baseline, max_disp=128).cuda().eval()
81 |
82 | net.load_state_dict(torch.load(model_path))
83 |
84 | with torch.no_grad():
85 | sample_list = [os.path.join(data_root, o) for o in os.listdir(data_root) if os.path.isdir(os.path.join(data_root,o))]
86 | for sample_path in tqdm(sample_list, ascii = True):
87 | with h5py.File(os.path.join(sample_path, f'frames.hdf5'), "r") as frames_sample:
88 | if args.architecture == 'single_frame':
89 | im = torch.tensor(frames_sample['im']).cuda()
90 | im_lcn, im_std = lcn_in(im)
91 | im = torch.cat([im_lcn, im], dim=1)
92 | disp = net(im)[0].cpu().numpy()
93 | elif args.architecture == 'multi_frame':
94 | im = torch.tensor(frames_sample['im']).cuda()
95 | im_lcn, im_std = lcn_in(im)
96 | im = torch.cat([im_lcn, im], dim=1)
97 |
98 | amb = torch.tensor(frames_sample['ambient']).cuda()
99 |
100 | R = torch.tensor(frames_sample['R']).cuda()
101 | t = torch.tensor(frames_sample['t']).cuda()
102 |
103 | flow = {}
104 | with h5py.File(os.path.join(sample_path, f'flow.hdf5'), "r") as sample_flow:
105 | for i0 in range(4):
106 | for i1 in range(4):
107 | if i0 != i1:
108 | flow[f'flow_{i0}{i1}'] = torch.tensor(sample_flow[f'flow_{i0}{i1}'][:]).cuda()
109 |
110 | with h5py.File(os.path.join(sample_path, f'single_frame_disp.hdf5'), "r") as sample_primary_disp:
111 | primary_disp = torch.tensor(sample_primary_disp['disp'][:]).cuda()
112 |
113 | disp = net(im.unsqueeze(1), amb.unsqueeze(1), primary_disp.unsqueeze(1), d2d(primary_disp.unsqueeze(1)), R.unsqueeze(1), t.unsqueeze(1), flow).detach().cpu().numpy()
114 | disp = disp[:, 0]
115 |
116 | with h5py.File(os.path.join(sample_path, f'{args.architecture}_disp.hdf5'), "w") as f:
117 | f.create_dataset('disp', data=disp)
--------------------------------------------------------------------------------
/data/presave_optical_flow_data.py:
--------------------------------------------------------------------------------
1 | # DepthInSpace is a PyTorch-based program which estimates 3D depth maps
2 | # from active structured-light sensor's multiple video frames.
3 | #
4 | # MIT License
5 | #
6 | # Copyright, 2021 ams International AG
7 | #
8 | # Permission is hereby granted, free of charge, to any person obtaining a copy
9 | # of this software and associated documentation files (the "Software"), to deal
10 | # in the Software without restriction, including without limitation the rights
11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | # copies of the Software, and to permit persons to whom the Software is
13 | # furnished to do so, subject to the following conditions:
14 | #
15 | # The above copyright notice and this permission notice shall be included in all
16 | # copies or substantial portions of the Software.
17 | #
18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24 | # SOFTWARE.
25 |
26 | from pathlib import Path
27 | import os
28 | import json
29 | import pickle
30 |
31 | if __name__ == '__main__':
32 |
33 |     # read the data directory and LiteFlowNet location from config.json
34 | config_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'config.json'))
35 | with open(config_path) as fp:
36 | config = json.load(fp)
37 | data_root = Path(config['DATA_DIR'])
38 | liteflownet_path = Path(config['LITEFLOWNET_DIR'])
39 |
40 | script_path = str(liteflownet_path / 'run.py')
41 | data_path = str(data_root)
42 | os.environ['PYTHONPATH'] = str(liteflownet_path) + os.pathsep + str(liteflownet_path / 'correlation')
43 | os.system('python ' + script_path + ' --model default' + ' --data_path ' + data_path)
44 |
--------------------------------------------------------------------------------
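
The script launches LiteFlowNet through os.system, which silently ignores the child process's exit status. A hedged alternative sketch (not part of the repository), reusing the liteflownet_path and data_root variables from the script above, so a failed flow-estimation run raises an exception instead of passing unnoticed:

    import os
    import subprocess

    env = dict(os.environ)
    env['PYTHONPATH'] = str(liteflownet_path) + os.pathsep + str(liteflownet_path / 'correlation')
    subprocess.run(
        ['python', str(liteflownet_path / 'run.py'), '--model', 'default', '--data_path', str(data_root)],
        env=env,
        check=True,   # raise CalledProcessError if run.py exits with a non-zero status
    )
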
/data/real_pattern.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/idiap/DepthInSpace/fe759807f82df4c48c16b97f061718175ea0e6e9/data/real_pattern.png
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 | # DepthInSpace is a PyTorch-based program which estimates 3D depth maps
2 | # from multiple video frames of an active structured-light sensor.
3 | #
4 | # Copyright (c) 2021 ams International AG
5 | #
6 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
7 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation
8 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
9 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
10 | #
11 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
12 | #
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
14 | # OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
15 | # LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
16 | # IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
--------------------------------------------------------------------------------
/model/ext_functions.py:
--------------------------------------------------------------------------------
1 | # DepthInSpace is a PyTorch-based program which estimates 3D depth maps
2 | # from multiple video frames of an active structured-light sensor.
3 | #
4 | # MIT License
5 | #
6 | # Copyright (c) 2019 autonomousvision
7 | #
8 | # Permission is hereby granted, free of charge, to any person obtaining a copy
9 | # of this software and associated documentation files (the "Software"), to deal
10 | # in the Software without restriction, including without limitation the rights
11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | # copies of the Software, and to permit persons to whom the Software is
13 | # furnished to do so, subject to the following conditions:
14 | #
15 | # The above copyright notice and this permission notice shall be included in all
16 | # copies or substantial portions of the Software.
17 | #
18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24 | # SOFTWARE.
25 |
26 | import torch
27 | import sys
28 | import json
29 | from pathlib import Path
30 | import os
31 |
32 | config_path = os.path.abspath(os.path.join(os.path.dirname( __file__ ), '..', 'config.json'))
33 | with open(config_path) as fp:
34 | config = json.load(fp)
35 | CTD_path = Path(config['CTD_DIR'])
36 | sys.path.append(str(CTD_path / 'torchext'))
37 |
38 | import ext_cpu
39 | import ext_cuda
40 |
41 | class NNFunction(torch.autograd.Function):
42 | @staticmethod
43 | def forward(ctx, in0, in1):
44 | args = (in0, in1)
45 | if in0.is_cuda:
46 | out = ext_cuda.nn_cuda(*args)
47 | else:
48 | out = ext_cpu.nn_cpu(*args)
49 | return out
50 |
51 | @staticmethod
52 | def backward(ctx, grad_out):
53 | return None, None
54 |
55 | def nn(in0, in1):
56 | return NNFunction.apply(in0, in1)
57 |
58 |
59 | class CrossCheckFunction(torch.autograd.Function):
60 | @staticmethod
61 | def forward(ctx, in0, in1):
62 | args = (in0, in1)
63 | if in0.is_cuda:
64 | out = ext_cuda.crosscheck_cuda(*args)
65 | else:
66 | out = ext_cpu.crosscheck_cpu(*args)
67 | return out
68 |
69 | @staticmethod
70 | def backward(ctx, grad_out):
71 | return None, None
72 |
73 | def crosscheck(in0, in1):
74 | return CrossCheckFunction.apply(in0, in1)
75 |
76 | class ProjNNFunction(torch.autograd.Function):
77 | @staticmethod
78 | def forward(ctx, xyz0, xyz1, K, patch_size):
79 | args = (xyz0, xyz1, K, patch_size)
80 | if xyz0.is_cuda:
81 | out = ext_cuda.proj_nn_cuda(*args)
82 | else:
83 | out = ext_cpu.proj_nn_cpu(*args)
84 | return out
85 |
86 | @staticmethod
87 | def backward(ctx, grad_out):
88 | return None, None, None, None
89 |
90 | def proj_nn(xyz0, xyz1, K, patch_size):
91 | return ProjNNFunction.apply(xyz0, xyz1, K, patch_size)
92 |
93 |
94 |
95 | class XCorrVolFunction(torch.autograd.Function):
96 | @staticmethod
97 | def forward(ctx, in0, in1, n_disps, block_size):
98 | args = (in0, in1, n_disps, block_size)
99 | if in0.is_cuda:
100 | out = ext_cuda.xcorrvol_cuda(*args)
101 | else:
102 | out = ext_cpu.xcorrvol_cpu(*args)
103 | return out
104 |
105 | @staticmethod
106 | def backward(ctx, grad_out):
107 | return None, None, None, None
108 |
109 | def xcorrvol(in0, in1, n_disps, block_size):
110 | return XCorrVolFunction.apply(in0, in1, n_disps, block_size)
111 |
112 |
113 |
114 |
115 | class PhotometricLossFunction(torch.autograd.Function):
116 | @staticmethod
117 | def forward(ctx, es, ta, block_size, type, eps):
118 | args = (es, ta, block_size, type, eps)
119 | ctx.save_for_backward(es, ta)
120 | ctx.block_size = block_size
121 | ctx.type = type
122 | ctx.eps = eps
123 | if es.is_cuda:
124 | out = ext_cuda.photometric_loss_forward(*args)
125 | else:
126 | out = ext_cpu.photometric_loss_forward(*args)
127 | return out
128 |
129 | @staticmethod
130 | def backward(ctx, grad_out):
131 | es, ta = ctx.saved_tensors
132 | block_size = ctx.block_size
133 | type = ctx.type
134 | eps = ctx.eps
135 | args = (es, ta, grad_out.contiguous(), block_size, type, eps)
136 | if grad_out.is_cuda:
137 | grad_es = ext_cuda.photometric_loss_backward(*args)
138 | else:
139 | grad_es = ext_cpu.photometric_loss_backward(*args)
140 | return grad_es, None, None, None, None
141 |
142 | def photometric_loss(es, ta, block_size, type='mse', eps=0.1):
143 | type = type.lower()
144 | if type == 'mse':
145 | type = 0
146 | elif type == 'sad':
147 | type = 1
148 | elif type == 'census_mse':
149 | type = 2
150 | elif type == 'census_sad':
151 | type = 3
152 | else:
153 | raise Exception('invalid loss type')
154 | return PhotometricLossFunction.apply(es, ta, block_size, type, eps)
155 |
156 | def photometric_loss_pytorch(es, ta, block_size, type='mse', eps=0.1):
157 | type = type.lower()
158 | p = block_size // 2
159 | es_pad = torch.nn.functional.pad(es, (p,p,p,p), mode='replicate')
160 | ta_pad = torch.nn.functional.pad(ta, (p,p,p,p), mode='replicate')
161 | es_uf = torch.nn.functional.unfold(es_pad, kernel_size=block_size)
162 | ta_uf = torch.nn.functional.unfold(ta_pad, kernel_size=block_size)
163 | es_uf = es_uf.view(es.shape[0], es.shape[1], -1, es.shape[2], es.shape[3])
164 | ta_uf = ta_uf.view(ta.shape[0], ta.shape[1], -1, ta.shape[2], ta.shape[3])
165 | if type == 'mse':
166 | ref = (es_uf - ta_uf)**2
167 | elif type == 'sad':
168 | ref = torch.abs(es_uf - ta_uf)
169 | elif type == 'census_mse' or type == 'census_sad':
170 | des = es_uf - es.unsqueeze(2)
171 | dta = ta_uf - ta.unsqueeze(2)
172 | h_des = 0.5 * (1 + des / torch.sqrt(des * des + eps))
173 | h_dta = 0.5 * (1 + dta / torch.sqrt(dta * dta + eps))
174 | diff = h_des - h_dta
175 | if type == 'census_mse':
176 | ref = diff * diff
177 | elif type == 'census_sad':
178 | ref = torch.abs(diff)
179 | else:
180 | raise Exception('invalid loss type')
181 | ref = ref.view(es.shape[0], -1, es.shape[2], es.shape[3])
182 | ref = torch.sum(ref, dim=1, keepdim=True) / block_size**2
183 | return ref
--------------------------------------------------------------------------------
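
photometric_loss_pytorch is the plain-PyTorch counterpart of the compiled photometric_loss extension and returns a per-pixel loss map. A minimal usage sketch with random tensors; note that importing model.ext_functions also requires config.json and the compiled torchext extensions from the CTD repository, since ext_cpu/ext_cuda are imported at module load time:

    import torch
    from model.ext_functions import photometric_loss_pytorch

    es = torch.rand(2, 1, 64, 64)   # e.g. a disparity-warped projector pattern
    ta = torch.rand(2, 1, 64, 64)   # e.g. the captured IR image
    loss_map = photometric_loss_pytorch(es, ta, block_size=9, type='census_sad', eps=0.5)
    print(loss_map.shape)           # torch.Size([2, 1, 64, 64]): one census-SAD value per pixel
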
/model/multi_frame_networks.py:
--------------------------------------------------------------------------------
1 | # DepthInSpace is a PyTorch-based program which estimates 3D depth maps
2 | # from multiple video frames of an active structured-light sensor.
3 | #
4 | # MIT License
5 | #
6 | # Copyright 2021, ams International AG
7 | #
8 | # Permission is hereby granted, free of charge, to any person obtaining a copy
9 | # of this software and associated documentation files (the "Software"), to deal
10 | # in the Software without restriction, including without limitation the rights
11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | # copies of the Software, and to permit persons to whom the Software is
13 | # furnished to do so, subject to the following conditions:
14 | #
15 | # The above copyright notice and this permission notice shall be included in all
16 | # copies or substantial portions of the Software.
17 | #
18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24 | # SOFTWARE.
25 |
26 | import cv2
27 |
28 | import torch
29 | import numpy as np
30 |
31 | from torch.utils.checkpoint import checkpoint
32 |
33 | from .networks import TimedModule, OutputLayerFactory
34 |
35 |
36 | def merge_tl_bs(x):
37 | return x.contiguous().view(-1, *x.shape[2:])
38 |
39 | def split_tl_bs(x, tl, bs):
40 | return x.contiguous().view(tl, bs, *x.shape[1:])
41 |
42 | def resize_like(x, target):
43 | x_shape = x.shape[:-3]
44 | height = target.shape[-2]
45 | width = target.shape[-1]
46 |
47 | out = torch.nn.functional.interpolate(x.contiguous().view(-1, *x.shape[-3:]), size=(height, width), mode='bilinear',
48 | align_corners=True)
49 | out = out.view(*x_shape, *out.shape[1:])
50 |
51 | return out
52 |
53 |
54 | def resize_flow_like(flow, target):
55 | height = target.shape[-2]
56 | width = target.shape[-1]
57 |
58 | out = {}
59 | for key, val in flow.items():
60 | flow_height = val.shape[-2]
61 | flow_width = val.shape[-1]
62 |
63 | resized_flow = torch.nn.functional.interpolate(val, size=(height, width), mode='bilinear', align_corners=True)
64 | resized_flow[:, 0, :, :] *= float(width) / float(flow_width)
65 | resized_flow[:, 1, :, :] *= float(height) / float(flow_height)
66 | out[key] = resized_flow
67 |
68 | return out
69 |
70 | def resize_flow_masks_like(flow_masks, target):
71 | height = target.shape[-2]
72 | width = target.shape[-1]
73 |
74 | with torch.no_grad():
75 | out = {}
76 | for key, val in flow_masks.items():
77 | resized_mask = torch.nn.functional.interpolate(val, size=(height, width), mode='bilinear', align_corners=True)
78 |
79 | out[key] = (resized_mask > 0.5).float()
80 |
81 | return out
82 |
83 | def warp(x, flow):
84 | width = x.shape[-1]
85 | height = x.shape[-2]
86 |
87 | u, v = np.meshgrid(range(width), range(height))
88 | u = torch.from_numpy(u.astype('float32')).to(x.device)
89 | v = torch.from_numpy(v.astype('float32')).to(x.device)
90 |
91 | uv_prj = flow.clone().permute(0, 2, 3, 1)
92 | uv_prj[..., 0] += u
93 | uv_prj[..., 1] += v
94 |
95 | uv_prj[..., 0] = 2 * (uv_prj[..., 0] / (width - 1) - 0.5)
96 | uv_prj[..., 1] = 2 * (uv_prj[..., 1] / (height - 1) - 0.5)
97 | x_prj = torch.nn.functional.grid_sample(x, uv_prj, padding_mode='zeros', align_corners=True)
98 |
99 | return x_prj
100 |
101 | class FuseNet(TimedModule):
102 | '''
103 |     FuseNet: fuses features and 3D points warped across video frames via optical flow to refine the per-frame input disparities
104 | '''
105 | def __init__(self, imsize, K, baseline, track_length = 4, block_num = 4, channels = 32, max_disp= 128, movement_mask_en = 1):
106 | super(FuseNet, self).__init__(mod_name='FuseNet')
107 | self.movement_mask_en = movement_mask_en
108 |
109 | self.im_height = imsize[0]
110 | self.im_width = imsize[1]
111 | self.core_height = self.im_height // 2
112 | self.core_width = self.im_width // 2
113 | self.K = K
114 | self.Ki = np.linalg.inv(K)
115 | self.baseline = baseline
116 | self.track_length = track_length
117 | self.block_num = block_num
118 | self.channels = channels
119 | self.max_disp = max_disp
120 |
121 | u, v = np.meshgrid(range(self.im_width), range(self.im_height))
122 | u = cv2.resize(u, (self.core_width, self.core_height), interpolation = cv2.INTER_NEAREST)
123 | v = cv2.resize(v, (self.core_width, self.core_height), interpolation = cv2.INTER_NEAREST)
124 | uv = np.stack((u,v,np.ones_like(u)), axis=2).reshape(-1,3)
125 |
126 | ray = uv @ self.Ki.T
127 | ray = ray.reshape(1, 1,-1,3).astype(np.float32)
128 | self.ray = torch.from_numpy(ray).cuda()
129 |
130 | self.conv1 = self.conv(4, self.channels // 2, kernel_size=4, stride=2)
131 | self.conv2 = self.conv(self.channels // 2, self.channels // 2, kernel_size=3, stride=1)
132 |
133 | # self.conv3 = self.conv(self.channels // 2, self.channels, kernel_size=4, stride=2)
134 | self.conv3 = self.conv(self.channels // 2, self.channels, kernel_size=3, stride=1)
135 | self.conv4 = self.conv(self.channels, self.channels, kernel_size=3, stride=1)
136 |
137 | self.res1 = ResNetBlock(self.channels)
138 | self.res2 = ResNetBlock(self.channels)
139 | self.res3 = ResNetBlock(self.channels)
140 |
141 | self.blocks = torch.nn.ModuleList([Block2D3D(channels = self.channels, tl= self.track_length) for i in range(self.block_num)])
142 |
143 | self.upconv1 = self.upconv(self.channels, self.channels)
144 | self.upconv2 = self.upconv(self.channels, self.channels)
145 |
146 | self.amb_conv = self.conv(1, 16, kernel_size=3, stride=1)
147 | self.amb_res1 = ResNetBlock(16)
148 | self.amb_res2 = ResNetBlock(16)
149 |
150 | self.ref_conv = self.conv(16 + self.channels, 32, kernel_size=3, stride=1)
151 | self.ref_res1 = ResNetBlock(32)
152 | self.ref_res2 = ResNetBlock(32)
153 | self.ref_res3 = ResNetBlock(32)
154 |
155 | self.final_conv = self.conv(32, 16, kernel_size=3, stride=1)
156 |
157 | self.predict_disp = OutputLayerFactory( type='disp', params={ 'alpha': self.max_disp, 'beta': 0, 'gamma': 1, 'offset': 3})(16)
158 |
159 | def conv(self, in_planes, out_planes, kernel_size=3, stride=1):
160 | return torch.nn.Sequential(
161 | torch.nn.ZeroPad2d((kernel_size - 1) // 2),
162 | torch.nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=0),
163 | torch.nn.SELU(inplace=True)
164 | )
165 |
166 | def upconv(self, in_planes, out_planes):
167 | return torch.nn.Sequential(
168 | torch.nn.ConvTranspose2d(in_planes, out_planes, kernel_size=4, stride=2, padding=1),
169 | torch.nn.SELU(inplace=True)
170 | )
171 |
172 | def unproject(self, d, R, t):
173 | tl = d.shape[0]
174 | bs = d.shape[1]
175 | xyz = d.view(tl, bs, -1, 1) * self.ray
176 | xyz = xyz - t.view(tl, bs, 1, 3)
177 | xyz = torch.matmul(xyz, R)
178 |
179 | return xyz
180 |
181 | def change_view_angle(self, xyz, R, t):
182 | xyz_changed = torch.matmul(xyz, R.transpose(1,2))
183 | xyz_changed = xyz_changed + t.unsqueeze(1).unsqueeze(0)
184 |
185 | return xyz_changed
186 |
187 | def gather_warped_xyz(self, tidx, xyz, depth, flow, amb):
188 | tl = xyz.shape[0]
189 | bs = xyz.shape[1]
190 |
191 | frame_inds = [i for i in range(tl) if i != tidx]
192 |
193 | warped_xyz = []
194 | warped_xyz.append(xyz[tidx].transpose(1, 2).view(bs, 3, self.core_height, self.core_width))
195 |
196 | warped_mask = []
197 | warped_mask.append(torch.ones(bs, 1, self.core_height, self.core_width).to(xyz.device))
198 |
199 | for j in frame_inds:
200 | warped_xyz.append(warp(xyz[j].transpose(1, 2).view(bs, 3, self.core_height, self.core_width), flow[f'flow_{tidx}{j}']))
201 |
202 | with torch.no_grad():
203 | flow0 = flow[f'flow_{tidx}{j}'].detach()
204 | flow10 = warp(flow[f'flow_{j}{tidx}'].detach(), flow0)
205 | fb_mask = ((flow0.detach() + flow10) ** 2).sum(dim=1) < 0.5 + 0.01 * (
206 | (flow0.detach() ** 2).sum(dim=1) + (flow10 ** 2).sum(dim=1))
207 | fb_mask = fb_mask.type(torch.float32).unsqueeze(1)
208 |
209 | warped_mask.append(fb_mask)
210 |
211 | warped_xyz = torch.stack(warped_xyz, dim=0)
212 | warped_mask = torch.stack(warped_mask, dim=0)
213 |
214 | return warped_xyz, warped_mask
215 |
216 | def pre_process(self, input_data, d, checkpoint_var):
217 | out_conv1 = self.conv1(torch.cat([input_data, d], dim= 1))
218 | out_conv2 = self.conv2(out_conv1)
219 |
220 | out_conv3 = self.conv3(out_conv2)
221 | out_conv4 = self.conv4(out_conv3)
222 |
223 | out_res1 = self.res1(out_conv4)
224 | out_res2 = self.res2(out_res1)
225 | feat = self.res3(out_res2)
226 |
227 | return feat
228 |
229 | def process_amb(self, amb, feat):
230 | out_amb_conv = self.amb_conv(merge_tl_bs(amb))
231 | out_amb_res1 = self.amb_res1(out_amb_conv)
232 | out_amb_res2 = self.amb_res2(out_amb_res1)
233 |
234 | out_process_upc = self.process_upc(feat, out_amb_res2)
235 |
236 | return out_process_upc
237 |
238 | def process_upc(self, feat, out_process_amb):
239 | # out_upconv1 = self.upconv1(feat)
240 | # out_upconv2 = self.upconv2(out_upconv1)
241 | # out_upconv = out_upconv2
242 |
243 | # out_upconv1 = self.upconv1(feat)
244 | # out_upconv = out_upconv1
245 |
246 | out_upconv = torch.nn.functional.interpolate(feat, size=(self.im_height, self.im_width), mode='bilinear',
247 | align_corners=True)
248 |
249 | out_ref_conv = self.ref_conv(torch.cat([out_upconv, out_process_amb], dim=1))
250 |
251 | return out_ref_conv
252 |
253 | def post_process(self, feat, amb):
254 | out_process_amb = checkpoint(self.process_amb, amb, feat, preserve_rng_state= False)
255 | # out_process_amb = self.process_amb(amb, feat)
256 |
257 | out_ref_res1 = checkpoint(self.ref_res1, out_process_amb, preserve_rng_state= False)
258 | # out_ref_res1 = self.ref_res1(out_process_amb)
259 | out_ref_res2 = checkpoint(self.ref_res2, out_ref_res1, preserve_rng_state= False)
260 | # out_ref_res2 = self.ref_res2(out_ref_res1)
261 | out_ref_res3 = checkpoint(self.ref_res3, out_ref_res2, preserve_rng_state= False)
262 | # out_ref_res3 = self.ref_res3(out_ref_res2)
263 |
264 | out_final_conv = self.final_conv(out_ref_res3)
265 | disp = self.predict_disp(out_final_conv)
266 |
267 | return disp
268 |
269 | def tforward(self, ir, amb, d, depth, R, t, flow):
270 | tl = ir.shape[0]
271 | bs = ir.shape[1]
272 |
273 | input_data = merge_tl_bs(torch.cat((ir, amb), 2))
274 | checkpoint_var = torch.tensor([0.0]).to(input_data.device).requires_grad_(True)
275 | # feat = checkpoint(self.pre_process, input_data, merge_tl_bs(d), checkpoint_var, preserve_rng_state= False)
276 | feat = self.pre_process(input_data, merge_tl_bs(d), checkpoint_var)
277 |
278 | ###### Block Part
279 | core_feat = split_tl_bs(feat, tl, bs)
280 | core_depth = resize_like(depth, core_feat)
281 | core_flow = resize_flow_like(flow, core_feat)
282 | core_amb = resize_like(amb, core_feat)
283 | xyz = self.unproject(core_depth, R, t)
284 |
285 | warped_xyz = []
286 | warped_mask = []
287 | for tidx in range(tl):
288 | xyz_changed = self.change_view_angle(xyz, R[tidx], t[tidx])
289 | w_xyz, w_mask = self.gather_warped_xyz(tidx, xyz_changed, core_depth, core_flow, core_amb)
290 |
291 | warped_xyz.append(w_xyz)
292 | warped_mask.append(w_mask)
293 | warped_xyz = torch.stack(warped_xyz, dim=0)
294 | warped_mask = torch.stack(warped_mask, dim=0)
295 |
296 | for block in self.blocks:
297 | core_feat = block(core_feat, warped_xyz, warped_mask, core_flow)
298 |
299 | feat = merge_tl_bs(core_feat)
300 | #### End of Block Part
301 |
302 | disp = self.post_process(feat, amb)
303 | out = disp.view(tl, bs, *disp.shape[1:])
304 |
305 | return out
306 |
307 | class Block2D3D(TimedModule):
308 | def __init__(self, channels, tl):
309 | super(Block2D3D, self).__init__(mod_name='Block2D3D')
310 |
311 | self.channels = channels
312 | self.tl = tl
313 |
314 | self.conv_mf = self.conv(self.channels * self.tl, self.channels, kernel_size=1, stride=1, activation='none')
315 |
316 | self.conv1_1 = self.conv(self.channels, self.channels, kernel_size=3, stride=1, activation='relu')
317 | self.conv1_2 = self.conv(self.channels, self.channels, kernel_size=3, stride=1, activation='relu')
318 |
319 | self.conv2_1 = self.conv(self.channels, self.channels, kernel_size=4, stride=2, activation='relu')
320 | self.conv2_2 = self.conv(self.channels, self.channels, kernel_size=3, stride=1, activation='relu')
321 |
322 | self.conv_fuse = self.conv(self.channels * 3, self.channels, kernel_size=3, stride=1, activation='none')
323 |
324 | # self.conv_res = self.conv(self.channels, self.channels, kernel_size=3, stride=1, activation='relu')
325 | self.activation_res = torch.nn.SELU(inplace=True)
326 |
327 | self.conv3d_1 = Conv3D(channels_in= self.channels, channels_out= self.channels, tl= self.tl, stride= 2)
328 | self.conv3d_2 = Conv3D(channels_in= self.channels, channels_out= self.channels, tl= self.tl, stride= 1)
329 |
330 | def conv(self, in_planes, out_planes, kernel_size=3, stride=1, activation = 'none'):
331 | if activation == 'none':
332 | return torch.nn.Sequential(
333 | torch.nn.ZeroPad2d((kernel_size - 1) // 2),
334 | torch.nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=0),
335 | # torch.nn.BatchNorm2d(out_planes)
336 | torch.nn.GroupNorm(num_groups= 1, num_channels= out_planes)
337 | )
338 | elif activation == 'relu':
339 | return torch.nn.Sequential(
340 | torch.nn.ZeroPad2d((kernel_size - 1) // 2),
341 | torch.nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=0),
342 | torch.nn.SELU(inplace=True),
343 | # torch.nn.BatchNorm2d(out_planes),
344 | torch.nn.GroupNorm(num_groups=1, num_channels=out_planes)
345 | )
346 |
347 | def gather_warped_feat(self, tidx, feat, flow):
348 | tl = feat.shape[0]
349 |
350 | frame_inds = [i for i in range(tl) if i != tidx]
351 |
352 | warped_feat = []
353 | warped_feat.append(feat[tidx])
354 |
355 | for j in frame_inds:
356 | warped_feat.append(warp(feat[j], flow[f'flow_{tidx}{j}']))
357 |
358 | warped_feat = torch.stack(warped_feat, dim=0)
359 |
360 | return warped_feat
361 |
362 | def tforward(self, feat, warped_xyz, warped_mask, flow):
363 | self.flow = flow
364 |
365 | out_conv3d_1, warped_feat = checkpoint(self.fwd_3d_1, feat, warped_xyz, warped_mask, preserve_rng_state= False)
366 | # out_conv3d_1, warped_feat = self.fwd_3d_1(feat, warped_xyz, warped_mask)
367 |
368 | out_conv3d_2, _ = checkpoint(self.fwd_3d_2, out_conv3d_1, warped_xyz, warped_mask, preserve_rng_state= False)
369 | # out_conv3d_2, _ = self.fwd_3d_2(out_conv3d_1, warped_xyz, warped_mask)
370 |
371 | out = checkpoint(self.fwd_2d, feat, warped_feat, warped_mask, out_conv3d_2, preserve_rng_state= False)
372 | # out = self.fwd_2d(feat, warped_feat, warped_mask, out_conv3d_2)
373 |
374 | return out
375 |
376 | def fwd_3d_1(self, feat, warped_xyz, warped_mask):
377 | tl = feat.shape[0]
378 |
379 | warped_feat = []
380 | out_conv3d = []
381 | for tidx in range(tl):
382 | warped_feat.append(self.gather_warped_feat(tidx, feat, self.flow))
383 | out_conv3d.append(self.conv3d_1(warped_xyz[tidx], warped_feat[-1], warped_mask[tidx]))
384 | warped_feat = torch.stack(warped_feat, dim=0)
385 | out_conv3d = torch.stack(out_conv3d, dim=0)
386 |
387 | return out_conv3d, warped_feat
388 |
389 | def fwd_3d_2(self, feat, warped_xyz, warped_mask):
390 | tl = feat.shape[0]
391 |
392 | resized_flow = resize_flow_like(self.flow, feat)
393 | resized_warped_xyz = resize_like(warped_xyz, feat)
394 | resized_warped_mask = (resize_like(warped_mask, feat) > 0.5).float()
395 |
396 | warped_feat = []
397 | out_conv3d = []
398 | for tidx in range(tl):
399 | warped_feat.append(self.gather_warped_feat(tidx, feat, resized_flow))
400 | out_conv3d.append(self.conv3d_2(resized_warped_xyz[tidx], warped_feat[-1], resized_warped_mask[tidx]))
401 | warped_feat = torch.stack(warped_feat, dim=0)
402 | out_conv3d = torch.stack(out_conv3d, dim=0)
403 |
404 | return out_conv3d, warped_feat
405 |
406 | def fwd_2d(self, feat, warped_feat, warped_mask, out_conv3d_2):
407 | tl = feat.shape[0]
408 | bs = feat.shape[1]
409 |
410 | warped_feat_2d = (warped_feat * warped_mask / warped_mask.mean(dim=1, keepdim=True)).transpose(1, 2)
411 | warped_feat_2d = warped_feat_2d.reshape(tl * bs, -1, *warped_feat_2d.shape[4:])
412 |
413 | out_conv_mf = self.conv_mf(warped_feat_2d)
414 |
415 | out_conv1_1 = self.conv1_1(out_conv_mf)
416 | out_conv1_2 = self.conv1_2(out_conv1_1)
417 |
418 | out_conv2_1 = self.conv2_1(out_conv_mf)
419 | out_conv2_2 = self.conv2_2(out_conv2_1)
420 | out_ups2 = torch.nn.functional.interpolate(out_conv2_2, scale_factor=2, mode='bilinear', align_corners=True)
421 |
422 | out_ups3d = torch.nn.functional.interpolate(merge_tl_bs(out_conv3d_2), scale_factor=2, mode='bilinear', align_corners=True)
423 |
424 | ### Fusion Part
425 | out_fuse = torch.cat((out_conv1_2, out_ups2, out_ups3d), dim= 1)
426 | out_conv_fuse = self.conv_fuse(out_fuse)
427 |
428 | out = self.activation_res(split_tl_bs(out_conv_fuse, tl, bs) + feat)
429 |
430 | return out
431 |
432 | class Conv3D(TimedModule):
433 | def __init__(self, channels_in, channels_out, neighbors = 9, tl = 4, ksize = 3, stride = 1, radius_sq = 0.04):
434 | super(Conv3D, self).__init__(mod_name='Conv3D')
435 | self.channels_in = channels_in
436 | self.channels_out = channels_out
437 | self.neighbors = neighbors
438 | self.tl = tl
439 | self.ksize = ksize
440 | self.stride = stride
441 | self.radius_sq = radius_sq
442 |
443 | self.dense1 = self.dense(3, self.channels_out // 2, activation= 'selu')
444 | self.dense2 = self.dense(self.channels_out // 2, self.channels_out, activation= 'selu')
445 |
446 | self.w = torch.nn.Parameter(torch.zeros([self.channels_out, self.channels_out]))
447 | torch.nn.init.xavier_uniform_(self.w, gain=0.1)
448 |
449 | self.activation = torch.nn.SELU(inplace=True)
450 | # self.bn = torch.nn.BatchNorm2d(channels_out)
451 | self.bn = torch.nn.GroupNorm(num_groups= 1, num_channels= channels_out)
452 |
453 | def dense(self, in_planes, out_planes, activation = 'selu'):
454 | if activation == 'selu':
455 | return torch.nn.Sequential(
456 | torch.nn.Linear(in_planes, out_planes),
457 | torch.nn.SELU(inplace=True),
458 | )
459 | elif activation == 'softmax':
460 | return torch.nn.Sequential(
461 | torch.nn.Linear(in_planes, out_planes),
462 | torch.nn.Softmax(dim = -1)
463 | )
464 | elif activation == 'none':
465 | return torch.nn.Sequential(
466 | torch.nn.Linear(in_planes, out_planes),
467 | )
468 |
469 | def tforward(self, xyz, feat, mask, checkpoint_var = 0):
470 | padding_len = (self.ksize - 1) // 2
471 |
472 | xyz = torch.nn.functional.pad(xyz, (padding_len, padding_len, padding_len, padding_len), mode='constant', value=0)
473 | feat = torch.nn.functional.pad(feat, (padding_len, padding_len, padding_len, padding_len), mode='constant', value=0)
474 | mask = torch.nn.functional.pad(mask, (padding_len, padding_len, padding_len, padding_len), mode='constant', value=0)
475 |
476 | xyz = xyz.unfold(3, self.ksize, self.stride).unfold(4, self.ksize, self.stride)
477 | feat = feat.unfold(3, self.ksize, self.stride).unfold(4, self.ksize, self.stride)
478 | mask = mask.unfold(3, self.ksize, self.stride).unfold(4, self.ksize, self.stride)
479 |
480 | xyz = xyz.permute(1, 3, 4, 5, 6, 0, 2) # (bs, h, w, k, k, tl, c)
481 | feat = feat.permute(1, 3, 4, 5, 6, 0, 2)
482 | mask = mask.permute(1, 3, 4, 5, 6, 0, 2)
483 |
484 | bs_h_w = xyz.shape[0:3]
485 | xyz = xyz.reshape(-1, self.ksize * self.ksize * self.tl, xyz.shape[-1]) # (?, k*k*tl, c)
486 | feat = feat.reshape(-1, self.ksize * self.ksize * self.tl, feat.shape[-1])
487 | mask = mask.reshape(-1, self.ksize * self.ksize * self.tl, mask.shape[-1])
488 |
489 | xyz_plane = xyz / (xyz[..., 2:] + 1e-12)
490 |
491 | tidx = ((self.ksize ** 2) // 2) * self.tl
492 | xyz_local = xyz - xyz[:, tidx:tidx + 1, :]
493 | xyz_plane_local = xyz_plane - xyz_plane[:, tidx:tidx + 1, :]
494 |
495 | xyz_sq = (xyz_plane_local ** 2).sum(dim=-1, keepdim=True)
496 |
497 | xyz_max_copy = (mask * xyz_sq) + (1 - mask) * (xyz_sq.max() + 1)
498 | _, neighbors_ind = torch.topk(xyz_max_copy, self.neighbors, dim= 1, largest=False, sorted=False)
499 | xyz_neighbors = torch.gather(xyz_local, dim = 1, index= neighbors_ind.expand(-1, -1, xyz_local.shape[-1]))
500 | feat_neighbors = torch.gather(feat, dim = 1, index= neighbors_ind.expand(-1, -1, feat.shape[-1]))
501 |
502 | out_dense1 = self.dense1(xyz_neighbors)
503 | out_dense2 = self.dense2(out_dense1)
504 |
505 | feat_weighted = (out_dense2 * feat_neighbors).sum(dim = 1)
506 |
507 | out_conv = torch.matmul(feat_weighted, self.w).view(*bs_h_w, self.channels_out).permute(0, 3, 1, 2)
508 | out_conv = self.activation(out_conv)
509 |
510 | out = self.bn(out_conv)
511 |
512 | return out
513 |
514 | class ResNetBlock(TimedModule):
515 | def __init__(self, planes):
516 | super(ResNetBlock, self).__init__(mod_name='ResNetBlock')
517 |
518 | self.pad = torch.nn.ZeroPad2d(1)
519 |
520 | self.conv1 = torch.nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=0)
521 | # self.bn1 = torch.nn.BatchNorm2d(planes)
522 | self.bn1 = torch.nn.GroupNorm(num_groups= 1, num_channels= planes)
523 | self.relu1 = torch.nn.SELU(inplace=True)
524 | self.conv2 = torch.nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=0)
525 | # self.bn2 = torch.nn.BatchNorm2d(planes)
526 | self.bn2 = torch.nn.GroupNorm(num_groups= 1, num_channels= planes)
527 | self.relu2 = torch.nn.SELU(inplace=True)
528 |
529 | def forward(self, x):
530 | identity = x.clone()
531 |
532 | out = self.conv1(self.pad(x))
533 | out = self.relu1(out)
534 | out = self.bn1(out)
535 |
536 | out = self.conv2(self.pad(out))
537 | out = self.bn2(out)
538 |
539 | out += identity
540 | out = self.relu2(out)
541 |
542 | return out
--------------------------------------------------------------------------------
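
The warp helper above backward-warps a tensor with a dense optical-flow field via grid_sample; FuseNet uses it to bring features and unprojected 3D points from the other frames of a track into the reference view. A self-contained sketch of the same operation (re-implemented here so it runs without the rest of the repository); with an all-zero flow the warp reduces to the identity:

    import numpy as np
    import torch

    def warp_with_flow(x, flow):
        # x: (B, C, H, W), flow: (B, 2, H, W) holding per-pixel (u, v) offsets
        _, _, H, W = x.shape
        u, v = np.meshgrid(range(W), range(H))
        u = torch.from_numpy(u.astype('float32')).to(x.device)
        v = torch.from_numpy(v.astype('float32')).to(x.device)
        grid = flow.permute(0, 2, 3, 1).clone()
        grid[..., 0] = 2 * ((grid[..., 0] + u) / (W - 1) - 0.5)   # normalize x to [-1, 1]
        grid[..., 1] = 2 * ((grid[..., 1] + v) / (H - 1) - 0.5)   # normalize y to [-1, 1]
        return torch.nn.functional.grid_sample(x, grid, padding_mode='zeros', align_corners=True)

    x = torch.rand(1, 3, 32, 32)
    warped = warp_with_flow(x, torch.zeros(1, 2, 32, 32))
    assert torch.allclose(warped, x, atol=1e-5)
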
/model/multi_frame_worker.py:
--------------------------------------------------------------------------------
1 | # DepthInSpace is a PyTorch-based program which estimates 3D depth maps
2 | # from multiple video frames of an active structured-light sensor.
3 | #
4 | # MIT License
5 | #
6 | # Copyright 2021, ams International AG
7 | #
8 | # Permission is hereby granted, free of charge, to any person obtaining a copy
9 | # of this software and associated documentation files (the "Software"), to deal
10 | # in the Software without restriction, including without limitation the rights
11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | # copies of the Software, and to permit persons to whom the Software is
13 | # furnished to do so, subject to the following conditions:
14 | #
15 | # The above copyright notice and this permission notice shall be included in all
16 | # copies or substantial portions of the Software.
17 | #
18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24 | # SOFTWARE.
25 |
26 | import torch
27 | import numpy as np
28 | import logging
29 | import itertools
30 | import matplotlib.pyplot as plt
31 | import co
32 |
33 | from data import base_dataset
34 | from data import dataset
35 |
36 | from model import networks
37 |
38 | from . import worker
39 |
40 | class Worker(worker.Worker):
41 | def __init__(self, args, **kwargs):
42 | super().__init__(args, **kwargs)
43 |
44 | self.disparity_loss = networks.DisparitySmoothLoss()
45 |
46 | def get_train_set(self):
47 | train_set = dataset.TrackSynDataset(self.settings_path, self.train_paths, train=True, data_aug=True, track_length=self.track_length, load_flow_data = True, load_primary_data = True, load_pseudo_gt = False, data_type = self.data_type)
48 | return train_set
49 |
50 | def get_test_sets(self):
51 | test_sets = base_dataset.TestSets()
52 | test_set = dataset.TrackSynDataset(self.settings_path, self.test_paths, train=False, data_aug=False, track_length=self.track_length, load_flow_data = True, load_primary_data = True, load_pseudo_gt = self.use_pseudo_gt, data_type = self.data_type)
53 | test_sets.append('simple', test_set, test_frequency=1)
54 |
55 | self.patterns = []
56 | self.ph_losses = []
57 | self.ge_losses = []
58 | self.d2ds = []
59 |
60 | self.lcn_in = self.lcn_in.to('cuda')
61 | for sidx in range(len(test_set.imsizes)):
62 | imsize = test_set.imsizes[sidx]
63 | pat = test_set.patterns[sidx]
64 | pat = pat.mean(axis=2)
65 | pat = torch.from_numpy(pat[None][None].astype(np.float32)).to('cuda')
66 | pat,_ = self.lcn_in(pat)
67 |
68 | self.patterns.append(pat)
69 |
70 | pat = torch.cat([pat for idx in range(3)], dim=1)
71 | ph_loss = networks.RectifiedPatternSimilarityLoss(imsize[0],imsize[1], pattern=pat)
72 |
73 | K = test_set.getK(sidx)
74 | Ki = np.linalg.inv(K)
75 | K = torch.from_numpy(K)
76 | Ki = torch.from_numpy(Ki)
77 | ge_loss = networks.Multi_Frame_Flow_Consistency_Loss(K, Ki, imsize[0], imsize[1], clamp=0.1)
78 |
79 | self.ph_losses.append( ph_loss )
80 | self.ge_losses.append( ge_loss )
81 |
82 | d2d = networks.DispToDepth(float(test_set.focal_lengths[sidx]), float(test_set.baseline))
83 | self.d2ds.append( d2d )
84 |
85 | return test_sets
86 |
87 | def net_forward(self, net, flow):
88 | im0 = self.data['im0']
89 | ambient0 = self.data['ambient0']
90 | disp0 = self.data['primary_disp']
91 | R = self.data['R']
92 | t = self.data['t']
93 |
94 | d2d = self.d2ds[0]
95 | depth = d2d(disp0)
96 |
97 | self.primary_disp = disp0[0, 0, 0, ...].detach().cpu().numpy()
98 |
99 | out = net(im0, ambient0, disp0, depth, R, t, flow)
100 |
101 | return out
102 |
103 | def loss_forward(self, out, train, flow_out = None):
104 | if not(isinstance(out, tuple) or isinstance(out, list)):
105 | out = [out]
106 |
107 | vals = []
108 |
109 | # apply photometric loss
110 | for s,o in zip(itertools.count(), out):
111 | im = self.data[f'im0']
112 | im = im.view(-1, *im.shape[2:])
113 | o = o.view(-1, *o.shape[2:])
114 | std = self.data[f'std0']
115 | std = std.view(-1, *std.shape[2:])
116 | val, pattern_proj = self.ph_losses[0](o, im[:,0:1,...], std)
117 | vals.append(val / (2 ** s))
118 |
119 | # apply disparity loss
120 | for s, o in zip(itertools.count(), out):
121 | if s == 0:
122 | amb0 = self.data[f'ambient0']
123 | amb0 = amb0.contiguous().view(-1, *amb0.shape[2:])
124 | o = o.view(-1, *o.shape[2:])
125 | val = self.disparity_loss(o, amb0)
126 | vals.append(val * 0.8 / (2 ** s))
127 |
128 | # apply geometric loss
129 | self.flow_mask = None
130 | R = self.data['R']
131 | t = self.data['t']
132 | primary_disp = self.data['primary_disp']
133 | amb = self.data['ambient0']
134 |
135 | ge_num = self.track_length * (self.track_length-1) / 2
136 | for sidx in range(1):
137 | d2d = self.d2ds[0]
138 | depth = d2d(out[sidx])
139 | primary_depth = d2d(primary_disp)
140 | ge_loss = self.ge_losses[0]
141 | for tidx0 in range(depth.shape[0]):
142 | for tidx1 in range(tidx0+1, depth.shape[0]):
143 | depth0 = depth[tidx0]
144 | R0 = R[tidx0]
145 | t0 = t[tidx0]
146 | amb0 = amb[tidx0]
147 | primary_depth0 = primary_depth[tidx0]
148 | flow0 = flow_out[f'flow_{tidx0}{tidx1}']
149 | depth1 = depth[tidx1]
150 | R1 = R[tidx1]
151 | t1 = t[tidx1]
152 | amb1 = amb[tidx1]
153 | primary_depth1 = primary_depth[tidx1]
154 | flow1 = flow_out[f'flow_{tidx1}{tidx0}']
155 |
156 | val = ge_loss(depth0, depth1, R0, t0, R1, t1, flow0, flow1, amb0, amb1, primary_depth0, primary_depth1)
157 | vals.append(val * 0.2 / ge_num / (2 ** sidx))
158 |
159 |         # warm-up: for the first two epochs, pull the estimate toward the single-frame (primary) disparity
160 | if train:
161 | if self.current_epoch < 2:
162 | for s, o in zip(itertools.count(), out):
163 | if s == 0:
164 | val = torch.mean(torch.abs(o - self.data['primary_disp']))
165 | vals.append(val * 0.1)
166 |
167 |             # on real data, additionally supervise with (noise-perturbed) SGM disparities during the warm-up epochs
168 | if self.current_epoch < self.warmup_epochs and self.data_type == 'real':
169 | for s, o in zip(itertools.count(), out):
170 | if s == 0:
171 | valid_mask = (self.data['sgm_disp'] > 30).float()
172 | val = torch.sum(torch.abs(o - self.data['sgm_disp'] + 1.5 * torch.randn(o.size()).cuda()) * valid_mask) / torch.sum(valid_mask)
173 | vals.append(val * 0.1)
174 |
175 | return vals
176 |
177 | def numpy_in_out(self, output):
178 | if not(isinstance(output, tuple) or isinstance(output, list)):
179 | output = [output]
180 | es = output[0].detach().to('cpu').numpy()
181 | gt = self.data['disp0'].detach().to('cpu').numpy().astype(np.float32)
182 | im = self.data['im0'][:,:,0:1,...].detach().to('cpu').numpy()
183 | amb = self.data['ambient0'].detach().to('cpu').numpy()
184 | pat = self.patterns[0].detach().to('cpu').numpy()
185 |
186 | es = es * (gt > 0)
187 |
188 | return es, gt, im, amb, pat
189 |
190 | def write_img(self, out_path, es, gt, im, amb, pat):
191 | logging.info(f'write img {out_path}')
192 |
193 | diff = np.abs(es - gt)
194 |
195 | vmin, vmax = np.nanmin(gt), np.nanmax(gt)
196 | vmin = vmin - 0.2*(vmax-vmin)
197 | vmax = vmax + 0.2*(vmax-vmin)
198 |
199 | vmax = np.max([vmax, 16])
200 |
201 | fig = plt.figure(figsize=(16,16))
202 | # plot pattern and input images
203 | ax = plt.subplot(3,3,1); plt.imshow(pat, vmin=pat.min(), vmax=pat.max(), cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'Projector Pattern')
204 | ax = plt.subplot(3,3,2); plt.imshow(im[0], vmin=im.min(), vmax=im.max(), cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 IR Input')
205 | ax = plt.subplot(3,3,3); plt.imshow(amb[0], vmin=amb.min(), vmax=amb.max(), cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 Ambient Input')
206 |
207 | # plot disparities, ground truth disparity is shown only for reference
208 | es0 = co.cmap.color_depth_map(es[0], scale=vmax)
209 | gt0 = co.cmap.color_depth_map(gt[0], scale=vmax)
210 | diff0 = co.cmap.color_error_image(diff[0], BGR=True)
211 |
212 | ax = plt.subplot(3,3,4); plt.imshow(gt0[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 Disparity GT {np.nanmin(gt[0]):.4f}/{np.nanmax(gt[0]):.4f}')
213 | ax = plt.subplot(3,3,5); plt.imshow(es0[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 Disparity Est. {es[0].min():.4f}/{es[0].max():.4f}')
214 | ax = plt.subplot(3,3,6); plt.imshow(diff0[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 Disparity Err. {diff[0].mean():.5f}')
215 |
216 | es1 = co.cmap.color_depth_map(self.primary_disp, scale=vmax)
217 | gt1 = co.cmap.color_depth_map(gt[0], scale=vmax)
218 | diff1_np = np.abs(self.primary_disp - gt[0])
219 | diff1 = co.cmap.color_error_image(diff1_np, BGR=True)
220 | ax = plt.subplot(3,3,7); plt.imshow(gt1[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 Disparity GT {np.nanmin(gt[0]):.4f}/{np.nanmax(gt[0]):.4f}')
221 | ax = plt.subplot(3,3,8); plt.imshow(es1[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 Disparity Input. {self.primary_disp.min():.4f}/{self.primary_disp.max():.4f}')
222 | ax = plt.subplot(3,3,9); plt.imshow(diff1[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 Disparity Input Err. {diff1_np.mean():.5f}')
223 |
224 | plt.tight_layout()
225 | plt.savefig(str(out_path))
226 | plt.close(fig)
227 |
228 | def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks):
229 | if batch_idx % 256 == 0:
230 | out_path = self.exp_output_dir / f'train_{epoch:03d}_{batch_idx:04d}.png'
231 | es, gt, im, amb, pat = self.numpy_in_out(output)
232 | self.write_img(out_path, es[:, 0, 0], gt[:, 0, 0], im[:, 0, 0], amb[:, 0, 0], pat[0, 0])
233 | torch.cuda.empty_cache()
234 |
235 | def callback_test_start(self, epoch, set_idx):
236 | self.metric = co.metric.MultipleMetric(
237 | co.metric.DistanceMetric(vec_length=1),
238 | co.metric.OutlierFractionMetric(vec_length=1, thresholds=[0.1, 0.5, 1, 2, 5])
239 | )
240 |
241 | def callback_test_add(self, epoch, set_idx, batch_idx, n_batches, output, masks):
242 | es, gt, im, amb, pat = self.numpy_in_out(output)
243 |
244 | if batch_idx % 8 == 0:
245 | out_path = self.exp_output_dir / f'test_{epoch:03d}_{batch_idx:04d}.png'
246 | self.write_img(out_path, es[:, 0, 0], gt[:, 0, 0], im[:, 0, 0], amb[:, 0, 0], pat[0, 0])
247 |
248 | es = self.crop_reshape(es)
249 | gt = self.crop_reshape(gt)
250 |
251 | self.metric.add(es, gt)
252 |
253 | def crop_reshape(self, input):
254 | output = input.reshape(-1, 1)
255 | return output
256 |
257 | def callback_test_stop(self, epoch, set_idx, loss):
258 | logging.info(f'{self.metric}')
259 | for k, v in self.metric.items():
260 | self.metric_add_test(epoch, set_idx, k, v)
261 |
262 | if __name__ == '__main__':
263 | pass
264 |
--------------------------------------------------------------------------------
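
The worker converts disparities to depth with networks.DispToDepth (the d2ds list above), i.e. depth = focal_length * baseline / disparity, where the focal length and baseline come from the dataset settings. A small numeric sketch with illustrative, made-up intrinsics:

    import torch

    focal_length = 500.0   # pixels, illustrative value only
    baseline = 0.075       # meters, illustrative value only
    disp = torch.tensor([10.0, 50.0, 100.0])
    depth = focal_length * baseline / torch.clamp(disp, min=1e-12)   # guard against zero disparity
    print(depth)   # tensor([3.7500, 0.7500, 0.3750])
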
/model/networks.py:
--------------------------------------------------------------------------------
1 | # DepthInSpace is a PyTorch-based program which estimates 3D depth maps
2 | # from multiple video frames of an active structured-light sensor.
3 | #
4 | #
5 | # MIT License
6 | #
7 | # Copyright (c) 2019 autonomousvision
8 | #
9 | # Permission is hereby granted, free of charge, to any person obtaining a copy
10 | # of this software and associated documentation files (the "Software"), to deal
11 | # in the Software without restriction, including without limitation the rights
12 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13 | # copies of the Software, and to permit persons to whom the Software is
14 | # furnished to do so, subject to the following conditions:
15 | #
16 | # The above copyright notice and this permission notice shall be included in all
17 | # copies or substantial portions of the Software.
18 | #
19 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
20 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
22 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
23 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
24 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
25 | # SOFTWARE.
26 | #
27 | #
28 | # MIT License
29 | #
30 | # Copyright 2021, ams International AG
31 | #
32 | # Permission is hereby granted, free of charge, to any person obtaining a copy
33 | # of this software and associated documentation files (the "Software"), to deal
34 | # in the Software without restriction, including without limitation the rights
35 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
36 | # copies of the Software, and to permit persons to whom the Software is
37 | # furnished to do so, subject to the following conditions:
38 | #
39 | # The above copyright notice and this permission notice shall be included in all
40 | # copies or substantial portions of the Software.
41 | #
42 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
43 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
44 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
45 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
46 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
47 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
48 | # SOFTWARE.
49 |
50 | import torch
51 | import torch.nn.functional as F
52 | import numpy as np
53 |
54 | import co
55 |
56 | from . import ext_functions
57 |
58 | class TimedModule(torch.nn.Module):
59 | def __init__(self, mod_name):
60 | super().__init__()
61 | self.mod_name = mod_name
62 |
63 | def tforward(self, *args, **kwargs):
64 | raise Exception('not implemented')
65 |
66 | def forward(self, *args, **kwargs):
67 | torch.cuda.synchronize()
68 | with co.gtimer.Ctx(self.mod_name):
69 | x = self.tforward(*args, **kwargs)
70 | torch.cuda.synchronize()
71 | return x
72 |
73 |
74 | class PosOutput(TimedModule):
75 | def __init__(self, channels_in, type, im_height, im_width, alpha=1, beta=0, gamma=1, offset=0):
76 | super().__init__(mod_name='PosOutput')
77 |         self.im_height = im_height
78 |         self.im_width = im_width
79 |
80 | if type == 'pos':
81 | self.layer = torch.nn.Sequential(
82 | torch.nn.Conv2d(channels_in, 1, kernel_size=3, padding=1),
83 | SigmoidAffine(alpha=alpha, beta=beta, gamma=gamma, offset=offset)
84 | )
85 | elif type == 'pos_row':
86 | self.layer = torch.nn.Sequential(
87 | MultiLinear(im_height, channels_in, 1),
88 | SigmoidAffine(alpha=alpha, beta=beta, gamma=gamma, offset=offset)
89 | )
90 |
91 | self.u_pos = None
92 |
93 | def tforward(self, x):
94 | if self.u_pos is None:
95 | self.u_pos = torch.arange(x.shape[3], dtype=torch.float32).view(1,1,1,-1)
96 | self.u_pos = self.u_pos.to(x.device)
97 | pos = self.layer(x)
98 | disp = self.u_pos - pos
99 | return disp
100 |
101 |
102 | class OutputLayerFactory(object):
103 | '''
104 | Define type of output
105 | type options:
106 |     linear: a single 3x3 convolution to one output channel, used for the edge decoder
107 | disp: estimate the disparity
108 | disp_row: independently estimate the disparity per row
109 | pos: estimate the absolute location
110 | pos_row: independently estimate the absolute location per row
111 | '''
112 | def __init__(self, type='disp', params={}):
113 | self.type = type
114 | self.params = params
115 |
116 | def __call__(self, channels_in, imsize = None):
117 |
118 | if self.type == 'linear':
119 | return torch.nn.Conv2d(channels_in, 1, kernel_size=3, padding=1)
120 |
121 | elif self.type == 'disp':
122 | return torch.nn.Sequential(
123 | torch.nn.Conv2d(channels_in, 1, kernel_size=3, padding=1),
124 | SigmoidAffine(**self.params)
125 | )
126 |
127 | elif self.type == 'disp_row':
128 | return torch.nn.Sequential(
129 | MultiLinear(imsize[0], channels_in, 1),
130 | SigmoidAffine(**self.params)
131 | )
132 |
133 | elif self.type == 'pos' or self.type == 'pos_row':
134 | return PosOutput(channels_in, **self.params)
135 |
136 | else:
137 | raise Exception('unknown output layer type')
138 |
139 |
140 | class SigmoidAffine(TimedModule):
141 | def __init__(self, alpha=1, beta=0, gamma=1, offset=0):
142 | super().__init__(mod_name='SigmoidAffine')
143 | self.alpha = alpha
144 | self.beta = beta
145 | self.gamma = gamma
146 | self.offset = offset
147 |
148 | def tforward(self, x):
149 | return torch.sigmoid(x/self.gamma - self.offset) * self.alpha + self.beta
150 |
151 |
152 | class MultiLinear(TimedModule):
153 | def __init__(self, n, channels_in, channels_out):
154 | super().__init__(mod_name='MultiLinear')
155 | self.channels_out = channels_out
156 | self.mods = torch.nn.ModuleList()
157 | for idx in range(n):
158 | self.mods.append(torch.nn.Linear(channels_in, channels_out))
159 |
160 | def tforward(self, x):
161 | x = x.permute(2,0,3,1) # BxCxHxW => HxBxWxC
162 | y = x.new_empty(*x.shape[:-1], self.channels_out)
163 | for hidx in range(x.shape[0]):
164 | y[hidx] = self.mods[hidx](x[hidx])
165 | y = y.permute(1,3,0,2) # HxBxWxC => BxCxHxW
166 | return y
167 |
168 |
169 |
170 | class DispNetS(TimedModule):
171 | '''
172 | Disparity Decoder based on DispNetS
173 | '''
174 | def __init__(self, channels_in, imsizes, output_facs, coordconv=False, weight_init=False, channel_multiplier=1):
175 | super(DispNetS, self).__init__(mod_name='DispNetS')
176 |
177 | conv_planes = channel_multiplier * np.array( [32, 64, 128, 256, 512, 512, 512] )
178 | self.conv1 = self.downsample_conv(channels_in, conv_planes[0], kernel_size=7)
179 | self.conv2 = self.downsample_conv(conv_planes[0], conv_planes[1], kernel_size=5)
180 | self.conv3 = self.downsample_conv(conv_planes[1], conv_planes[2])
181 | self.conv4 = self.downsample_conv(conv_planes[2], conv_planes[3])
182 | self.conv5 = self.downsample_conv(conv_planes[3], conv_planes[4])
183 | self.conv6 = self.downsample_conv(conv_planes[4], conv_planes[5])
184 | self.conv7 = self.downsample_conv(conv_planes[5], conv_planes[6])
185 |
186 | upconv_planes = channel_multiplier * np.array( [512, 512, 256, 128, 64, 32, 16] )
187 | self.upconv7 = self.upconv(conv_planes[6], upconv_planes[0])
188 | self.upconv6 = self.upconv(upconv_planes[0], upconv_planes[1])
189 | self.upconv5 = self.upconv(upconv_planes[1], upconv_planes[2])
190 | self.upconv4 = self.upconv(upconv_planes[2], upconv_planes[3])
191 | self.upconv3 = self.upconv(upconv_planes[3], upconv_planes[4])
192 | self.upconv2 = self.upconv(upconv_planes[4], upconv_planes[5])
193 | self.upconv1 = self.upconv(upconv_planes[5], upconv_planes[6])
194 |
195 | self.iconv7 = self.conv(upconv_planes[0] + conv_planes[5], upconv_planes[0])
196 | self.iconv6 = self.conv(upconv_planes[1] + conv_planes[4], upconv_planes[1])
197 | self.iconv5 = self.conv(upconv_planes[2] + conv_planes[3], upconv_planes[2])
198 | self.iconv4 = self.conv(upconv_planes[3] + conv_planes[2], upconv_planes[3])
199 | self.iconv3 = self.conv(1 + upconv_planes[4] + conv_planes[1], upconv_planes[4])
200 | self.iconv2 = self.conv(1 + upconv_planes[5] + conv_planes[0], upconv_planes[5])
201 | self.iconv1 = self.conv(1 + upconv_planes[6], upconv_planes[6])
202 |
203 |
204 | if isinstance(output_facs, list):
205 | self.predict_disp4 = output_facs[3](upconv_planes[3], imsizes[3])
206 | self.predict_disp3 = output_facs[2](upconv_planes[4], imsizes[2])
207 | self.predict_disp2 = output_facs[1](upconv_planes[5], imsizes[1])
208 | self.predict_disp1 = output_facs[0](upconv_planes[6], imsizes[0])
209 | else:
210 | self.predict_disp4 = output_facs(upconv_planes[3], imsizes[3])
211 | self.predict_disp3 = output_facs(upconv_planes[4], imsizes[2])
212 | self.predict_disp2 = output_facs(upconv_planes[5], imsizes[1])
213 | self.predict_disp1 = output_facs(upconv_planes[6], imsizes[0])
214 |
215 | # def init_weights(self):
216 | # for m in self.modules():
217 | # if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.ConvTranspose2d):
218 | # torch.nn.init.xavier_uniform_(m.weight, gain=0.1)
219 | # if m.bias is not None:
220 | # torch.nn.init.zeros_(m.bias)
221 |
222 | def downsample_conv(self, in_planes, out_planes, kernel_size=3):
223 | return torch.nn.Sequential(
224 | torch.nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=2, padding=(kernel_size-1)//2),
225 | torch.nn.ReLU(inplace=True),
226 | torch.nn.Conv2d(out_planes, out_planes, kernel_size=kernel_size, padding=(kernel_size-1)//2),
227 | torch.nn.ReLU(inplace=True)
228 | )
229 |
230 | def conv(self, in_planes, out_planes):
231 | return torch.nn.Sequential(
232 | torch.nn.Conv2d(in_planes, out_planes, kernel_size=3, padding=1),
233 | torch.nn.ReLU(inplace=True)
234 | )
235 |
236 | def upconv(self, in_planes, out_planes):
237 | return torch.nn.Sequential(
238 | torch.nn.ConvTranspose2d(in_planes, out_planes, kernel_size=3, stride=2, padding=1, output_padding=1),
239 | torch.nn.ReLU(inplace=True)
240 | )
241 |
242 | def crop_like(self, input, ref):
243 | assert(input.size(2) >= ref.size(2) and input.size(3) >= ref.size(3))
244 | return input[:, :, :ref.size(2), :ref.size(3)]
245 |
246 | def tforward(self, x):
247 | out_conv1 = self.conv1(x)
248 | out_conv2 = self.conv2(out_conv1)
249 | out_conv3 = self.conv3(out_conv2)
250 | out_conv4 = self.conv4(out_conv3)
251 | out_conv5 = self.conv5(out_conv4)
252 | out_conv6 = self.conv6(out_conv5)
253 | out_conv7 = self.conv7(out_conv6)
254 |
255 | out_upconv7 = self.crop_like(self.upconv7(out_conv7), out_conv6)
256 | concat7 = torch.cat((out_upconv7, out_conv6), 1)
257 | out_iconv7 = self.iconv7(concat7)
258 |
259 | out_upconv6 = self.crop_like(self.upconv6(out_iconv7), out_conv5)
260 | concat6 = torch.cat((out_upconv6, out_conv5), 1)
261 | out_iconv6 = self.iconv6(concat6)
262 |
263 | out_upconv5 = self.crop_like(self.upconv5(out_iconv6), out_conv4)
264 | concat5 = torch.cat((out_upconv5, out_conv4), 1)
265 | out_iconv5 = self.iconv5(concat5)
266 |
267 | out_upconv4 = self.crop_like(self.upconv4(out_iconv5), out_conv3)
268 | concat4 = torch.cat((out_upconv4, out_conv3), 1)
269 | out_iconv4 = self.iconv4(concat4)
270 | disp4 = self.predict_disp4(out_iconv4)
271 |
272 | out_upconv3 = self.crop_like(self.upconv3(out_iconv4), out_conv2)
273 | disp4_up = self.crop_like(torch.nn.functional.interpolate(disp4, scale_factor=2, mode='bilinear', align_corners=False), out_conv2)
274 | concat3 = torch.cat((out_upconv3, out_conv2, disp4_up), 1)
275 | out_iconv3 = self.iconv3(concat3)
276 | disp3 = self.predict_disp3(out_iconv3)
277 |
278 | out_upconv2 = self.crop_like(self.upconv2(out_iconv3), out_conv1)
279 | disp3_up = self.crop_like(torch.nn.functional.interpolate(disp3, scale_factor=2, mode='bilinear', align_corners=False), out_conv1)
280 | concat2 = torch.cat((out_upconv2, out_conv1, disp3_up), 1)
281 | out_iconv2 = self.iconv2(concat2)
282 | disp2 = self.predict_disp2(out_iconv2)
283 |
284 | out_upconv1 = self.crop_like(self.upconv1(out_iconv2), x)
285 | disp2_up = self.crop_like(torch.nn.functional.interpolate(disp2, scale_factor=2, mode='bilinear', align_corners=False), x)
286 | concat1 = torch.cat((out_upconv1, disp2_up), 1)
287 | out_iconv1 = self.iconv1(concat1)
288 | disp1 = self.predict_disp1(out_iconv1)
289 |
290 | out1 = disp1
291 | out2 = torch.nn.functional.interpolate(input=disp2, size=(out1.size(2), out1.size(3)), mode='bilinear', align_corners=False)
292 | out3 = torch.nn.functional.interpolate(input=disp3, size=(out1.size(2), out1.size(3)), mode='bilinear', align_corners=False)
293 | out4 = torch.nn.functional.interpolate(input=disp4, size=(out1.size(2), out1.size(3)), mode='bilinear', align_corners=False)
294 |
295 | return (out1, out2, out3, out4)
296 |
297 | class DispDecoder(TimedModule):
298 | '''
299 | Disparity Decoder
300 | '''
301 | def __init__(self, *args, max_disp=128, **kwargs):
302 | super(DispDecoder, self).__init__(mod_name='DispDecoder')
303 |
304 | output_facs_disp = [OutputLayerFactory( type='disp', params={ 'alpha': max_disp/(2**s), 'beta': 0, 'gamma': 1, 'offset': 3}) for s in range(4)]
305 | self.disp_decoder = DispNetS(*args, output_facs=output_facs_disp, **kwargs)
306 |
307 | def tforward(self, x):
308 | disp = self.disp_decoder(x)
309 | return disp
310 |
311 | class DispToDepth(TimedModule):
312 | def __init__(self, focal_length, baseline):
313 | super().__init__(mod_name='DispToDepth')
314 | self.baseline_focal_length = baseline * focal_length
315 |
316 | def tforward(self, disp):
317 | disp = torch.nn.functional.relu(disp) + 1e-12
318 | depth = self.baseline_focal_length / disp
319 | return depth
320 |
321 | class PosToDepth(DispToDepth):
322 | def __init__(self, focal_length, baseline, im_height, im_width):
323 | super().__init__(focal_length, baseline)
324 | self.mod_name = 'PosToDepth'
325 |
326 | self.im_height = im_height
327 | self.im_width = im_width
328 | self.u_pos = torch.arange(im_width, dtype=torch.float32).view(1,1,1,-1)
329 |
330 | def tforward(self, pos):
331 | self.u_pos = self.u_pos.to(pos.device)
332 | disp = self.u_pos - pos
333 | return super().forward(disp)
334 |
335 |
336 | class RectifiedPatternSimilarityLoss(TimedModule):
337 | '''
338 | Photometric Loss
339 | '''
340 | def __init__(self, im_height, im_width, pattern, loss_type='census_sad', loss_eps=0.5):
341 | super().__init__(mod_name='RectifiedPatternSimilarityLoss')
342 | self.im_height = im_height
343 | self.im_width = im_width
344 | self.pattern = pattern.mean(dim=1, keepdim=True).contiguous()
345 |
346 | u, v = np.meshgrid(range(im_width), range(im_height))
347 | uv0 = np.stack((u,v), axis=2).reshape(-1,1)
348 | uv0 = uv0.astype(np.float32).reshape(1,-1,2)
349 | self.uv0 = torch.from_numpy(uv0)
350 |
351 | self.loss_type = loss_type
352 | self.loss_eps = loss_eps
353 |
354 | def tforward(self, disp0, im, std=None, output_mean=True):
355 | self.pattern = self.pattern.to(disp0.device)
356 | self.uv0 = self.uv0.to(disp0.device)
357 |
358 | uv0 = self.uv0.expand(disp0.shape[0], *self.uv0.shape[1:])
359 | uv1 = torch.empty_like(uv0)
360 | uv1[...,0] = uv0[...,0] - disp0.contiguous().view(disp0.shape[0],-1)
361 | uv1[...,1] = uv0[...,1]
362 |
363 | uv1[..., 0] = 2 * (uv1[..., 0] / (self.im_width-1) - 0.5)
364 | uv1[..., 1] = 2 * (uv1[..., 1] / (self.im_height-1) - 0.5)
365 | uv1 = uv1.view(-1, self.im_height, self.im_width, 2).clone()
366 | pattern = self.pattern.expand(disp0.shape[0], *self.pattern.shape[1:])
367 | pattern_proj = torch.nn.functional.grid_sample(pattern, uv1, padding_mode='border', align_corners=True)
368 | mask = torch.ones_like(im)
369 | if std is not None:
370 | mask = mask*std
371 |
372 | diff = ext_functions.photometric_loss(pattern_proj.contiguous(), im.contiguous(), 9, self.loss_type, self.loss_eps)
373 | if output_mean:
374 | val = (mask*diff).sum() / mask.sum()
375 | else:
376 | val = diff
377 | return val, pattern_proj
378 |
379 | class SSIM(TimedModule):
380 | """Layer to compute the SSIM loss between a pair of images
381 | """
382 | def __init__(self):
383 | super().__init__(mod_name='SSIM')
384 | self.mu_x_pool = torch.nn.AvgPool2d(3, 1)
385 | self.mu_y_pool = torch.nn.AvgPool2d(3, 1)
386 | self.sig_x_pool = torch.nn.AvgPool2d(3, 1)
387 | self.sig_y_pool = torch.nn.AvgPool2d(3, 1)
388 | self.sig_xy_pool = torch.nn.AvgPool2d(3, 1)
389 |
390 | self.refl = torch.nn.ReflectionPad2d(1)
391 |
392 | self.C1 = 0.01 ** 2
393 | self.C2 = 0.03 ** 2
394 |
395 | def tforward(self, x, y):
396 | x = self.refl(x)
397 | y = self.refl(y)
398 |
399 | mu_x = self.mu_x_pool(x)
400 | mu_y = self.mu_y_pool(y)
401 |
402 | sigma_x = self.sig_x_pool(x ** 2) - mu_x ** 2
403 | sigma_y = self.sig_y_pool(y ** 2) - mu_y ** 2
404 | sigma_xy = self.sig_xy_pool(x * y) - mu_x * mu_y
405 |
406 | SSIM_n = (2 * mu_x * mu_y + self.C1) * (2 * sigma_xy + self.C2)
407 | SSIM_d = (mu_x ** 2 + mu_y ** 2 + self.C1) * (sigma_x + sigma_y + self.C2)
408 |
409 | return torch.clamp((1 - SSIM_n / SSIM_d) / 2, 0, 1)
410 |
411 | class DisparitySmoothLoss(TimedModule):
412 | '''
413 |     Disparity Smoothness Loss
414 | '''
415 | def __init__(self):
416 |         super().__init__(mod_name='DisparitySmoothLoss')
417 | self.sobel = SobelFilter(norm=False, ksize=5)
418 |
419 | def tforward(self, disp, im):
420 | self.sobel=self.sobel.to(disp.device)
421 |
422 | mean_disp = disp.mean(2, True).mean(3, True)
423 | # norm_disp = disp / (mean_disp + 1e-7)
424 | norm_disp = disp
425 |
426 | grad = self.sobel(norm_disp)
427 | grad_im = self.sobel(im)
428 |
429 | val = torch.abs(grad * torch.exp(-torch.abs(255 * grad_im)))
430 |
431 | return val.mean()
432 |
433 | class ProjectionBaseLoss(TimedModule):
434 | '''
435 | Base module of the Geometric Loss
436 | '''
437 | def __init__(self, K, Ki, im_height, im_width):
438 | super().__init__(mod_name='ProjectionBaseLoss')
439 |
440 | self.K = K.view(-1,3,3)
441 |
442 | self.im_height = im_height
443 | self.im_width = im_width
444 |
445 | u, v = np.meshgrid(range(im_width), range(im_height))
446 | uv = np.stack((u,v,np.ones_like(u)), axis=2).reshape(-1,3)
447 |
448 | ray = uv @ Ki.numpy().T
449 |
450 | ray = ray.reshape(1,-1,3).astype(np.float32)
451 | self.ray = torch.from_numpy(ray)
452 | self.u = torch.from_numpy(u.astype('float32'))
453 | self.v = torch.from_numpy(v.astype('float32'))
454 |
455 | def transform(self, xyz, R=None, t=None):
456 | if t is not None:
457 | bs = xyz.shape[0]
458 | xyz = xyz - t.reshape(bs,1,3)
459 | if R is not None:
460 | xyz = torch.bmm(xyz, R)
461 | return xyz
462 |
463 | def unproject(self, depth, R=None, t=None):
464 | self.ray = self.ray.to(depth.device)
465 | bs = depth.shape[0]
466 |
467 | xyz = depth.reshape(bs,-1,1) * self.ray
468 | xyz = self.transform(xyz, R, t)
469 | return xyz
470 |
471 | def project(self, xyz, R, t, return_ray_format= False):
472 | self.K = self.K.to(xyz.device)
473 | bs = xyz.shape[0]
474 |
475 | xyz = torch.bmm(xyz, R.transpose(1,2))
476 | xyz = xyz + t.reshape(bs,1,3)
477 |
478 | if return_ray_format:
479 | uv = xyz
480 | else:
481 | Kt = self.K.transpose(1,2).expand(bs,-1,-1)
482 | uv = torch.bmm(xyz, Kt)
483 |
484 | d = uv[:,:,2:3]
485 |
486 | # avoid division by zero
487 | uv = uv[:,:,:2] / (torch.nn.functional.relu(d) + 1e-12)
488 | return uv, d
489 |
490 |
491 | def tforward(self, depth0, R0, t0, R1, t1, return_ray_format= False):
492 | xyz = self.unproject(depth0, R0, t0)
493 | return self.project(xyz, R1, t1, return_ray_format)
494 |
495 |
496 | class ProjectionDepthSimilarityLoss(ProjectionBaseLoss):
497 | '''
498 | Geometric Loss
499 | '''
500 | def __init__(self, *args, clamp=-1):
501 | super().__init__(*args)
502 | self.mod_name = 'ProjectionDepthSimilarityLoss'
503 | self.clamp = clamp
504 |
505 | def fwd(self, depth0, depth1, R0, t0, R1, t1):
506 | self.u = self.u.to(depth0.device)
507 | self.v = self.v.to(depth0.device)
508 |
509 | uv1, d1 = super().tforward(depth0, R0, t0, R1, t1)
510 | uv1 = uv1.view(-1, self.im_height, self.im_width, 2)
511 | d1 = d1.view(-1, 1, self.im_height, self.im_width)
512 |
513 |         # Calculate the rigid flow
514 | rigid_flow = uv1.clone()
515 | rigid_flow[..., 0] -= self.u
516 | rigid_flow[..., 1] -= self.v
517 | rigid_flow = rigid_flow.permute(0, 3, 1, 2)
518 |
519 | uv1[..., 0] = 2 * (uv1[..., 0] / (self.im_width-1) - 0.5)
520 | uv1[..., 1] = 2 * (uv1[..., 1] / (self.im_height-1) - 0.5)
521 | depth10 = torch.nn.functional.grid_sample(depth1, uv1, padding_mode='border', align_corners=True)
522 |
523 | diff = torch.abs(d1 - depth10)
524 |
525 | if self.clamp > 0:
526 | diff = torch.clamp(diff, 0, self.clamp)
527 |
528 | orig_mask = (diff.detach() < self.clamp).to('cpu').numpy().astype('float32')
529 |
530 | return diff.mean(), rigid_flow, orig_mask[0][0]
531 |
532 | def tforward(self, depth0, depth1, R0, t0, R1, t1):
533 | l0, rigid_flow0, orig_mask = self.fwd(depth0, depth1, R0, t0, R1, t1)
534 | l1, rigid_flow1, _ = self.fwd(depth1, depth0, R1, t1, R0, t0)
535 |
536 | with torch.no_grad():
537 | mask0 = self.generate_mask(rigid_flow0.detach(), rigid_flow1.detach())
538 | mask1 = self.generate_mask(rigid_flow1.detach(), rigid_flow0.detach())
539 |
540 | return l0+l1, rigid_flow0, rigid_flow1, mask0, mask1, orig_mask
541 |
542 | def generate_mask(self, flow0, flow1):
543 | uv1 = flow0.clone().permute(0, 2, 3, 1)
544 | uv1[..., 0] += self.u
545 | uv1[..., 1] += self.v
546 | uv1[..., 0] = 2 * (uv1[..., 0] / (self.im_width-1) - 0.5)
547 | uv1[..., 1] = 2 * (uv1[..., 1] / (self.im_height-1) - 0.5)
548 | flow0_proj = torch.nn.functional.grid_sample(flow1, uv1, padding_mode='border', align_corners=True)
549 | mask0 = ((flow0 + flow0_proj) ** 2).sum(dim=1) < 0.25 + 0.02 * ((flow0 ** 2).sum(dim=1) + (flow0_proj ** 2).sum(dim=1))
550 | mask0 = mask0.type(torch.float32).unsqueeze(1)
551 | return mask0
552 |
553 |
554 | class Multi_Frame_Flow_Consistency_Loss(ProjectionBaseLoss):
555 | '''
556 | Flow Consistency Loss for Multi-Frame Inference
557 | '''
558 |
559 | def __init__(self, *args, clamp=-1):
560 | super().__init__(*args)
561 | self.mod_name = 'Multi_Frame_Flow_Consistency_Loss'
562 | self.clamp = clamp
563 |
564 | def fwd(self, depth0, depth1, R0, t0, R1, t1, flow0, flow1, amb0, amb1, primary_depth1):
565 | self.u = self.u.to(depth0.device)
566 | self.v = self.v.to(depth0.device)
567 |
568 | uv1, d1 = super().tforward(depth0, R0, t0, R1, t1)
569 | uv1 = uv1.view(-1, self.im_height, self.im_width, 2)
570 | d1 = d1.view(-1, 1, self.im_height, self.im_width)
571 |
572 | uv1_flow = flow0.permute(0, 2, 3, 1).clone()
573 | uv1_flow[..., 0] += self.u
574 | uv1_flow[..., 1] += self.v
575 |
576 | uv1_flow[..., 0] = 2 * (uv1_flow[..., 0] / (self.im_width - 1) - 0.5)
577 | uv1_flow[..., 1] = 2 * (uv1_flow[..., 1] / (self.im_height - 1) - 0.5)
578 | depth10 = torch.nn.functional.grid_sample(depth1, uv1_flow, padding_mode='zeros', align_corners=True)
579 |
580 | diff = torch.abs(d1 - depth10)
581 |
582 | with torch.no_grad():
583 | flow10 = torch.nn.functional.grid_sample(flow1.detach(), uv1_flow.detach(), padding_mode='zeros', align_corners=True)
584 | fb_mask = ((flow0.detach() + flow10) ** 2).sum(dim=1) < 0.5 + 0.02 * (
585 | (flow0.detach() ** 2).sum(dim=1) + (flow10 ** 2).sum(dim=1))
586 | fb_mask = fb_mask.type(torch.float32).unsqueeze(1)
587 |
588 | amb10 = torch.nn.functional.grid_sample(amb1.detach(), uv1_flow.detach(), padding_mode='zeros', align_corners=True)
589 | vc_mask = (((amb0 - amb10).abs()).mean(dim=1, keepdim=True) < 0.01).type(torch.float32)
590 |
591 | uv0, d0 = super().tforward(primary_depth1.detach(), R1.detach(), t1.detach(), R0.detach(), t0.detach(), return_ray_format= False)
592 | uv0 = uv0.view(-1, self.im_height, self.im_width, 2).permute(0, 3, 1, 2)
593 | warped_uv0 = torch.nn.functional.grid_sample(uv0.detach(), uv1_flow.detach(), padding_mode='zeros', align_corners=True)
594 | self_uv = torch.stack([self.u, self.v], dim= 0).unsqueeze(0)
595 | rf_mask = (((warped_uv0 - self_uv) ** 2).sum(dim = 1, keepdim=True) < 1).type(torch.float32)
596 |
597 | loss_mask = fb_mask * vc_mask * rf_mask
598 |
599 | diff = (diff * loss_mask).sum() / (loss_mask.sum() + 1e-8)
600 |
601 | return diff
602 |
603 | def tforward(self, depth0, depth1, R0, t0, R1, t1, flow0, flow1, amb0, amb1, primary_depth0, primary_depth1):
604 | l0 = self.fwd(depth0, depth1, R0, t0, R1, t1, flow0, flow1, amb0, amb1, primary_depth1)
605 | l1 = self.fwd(depth1, depth0, R1, t1, R0, t0, flow1, flow0, amb1, amb0, primary_depth0)
606 |
607 | return l0 + l1
608 |
609 | class Single_Frame_Flow_Consistency_Loss(ProjectionBaseLoss):
610 | '''
611 | Flow Consistency Loss for Single Frame Inference
612 | '''
613 |
614 | def __init__(self, *args, clamp=-1):
615 | super().__init__(*args)
616 | self.mod_name = 'Single_Frame_Flow_Consistency_Loss'
617 | self.clamp = clamp
618 |
619 | def fwd(self, depth0, depth1, R0, t0, R1, t1, flow0, flow1, amb0, amb1):
620 | self.u = self.u.to(depth0.device)
621 | self.v = self.v.to(depth0.device)
622 |
623 | uv1, d1 = super().tforward(depth0, R0, t0, R1, t1)
624 | uv1 = uv1.view(-1, self.im_height, self.im_width, 2)
625 | d1 = d1.view(-1, 1, self.im_height, self.im_width)
626 |
627 | uv1_flow = flow0.permute(0, 2, 3, 1).clone()
628 | uv1_flow[..., 0] += self.u
629 | uv1_flow[..., 1] += self.v
630 |
631 | uv1_flow[..., 0] = 2 * (uv1_flow[..., 0] / (self.im_width - 1) - 0.5)
632 | uv1_flow[..., 1] = 2 * (uv1_flow[..., 1] / (self.im_height - 1) - 0.5)
633 | depth10 = torch.nn.functional.grid_sample(depth1, uv1_flow, padding_mode='zeros', align_corners=True)
634 |
635 | diff = torch.abs(d1 - depth10)
636 |
637 | if self.clamp > 0:
638 | diff = torch.clamp(diff, 0, self.clamp)
639 |
640 | orig_mask = (diff.detach() < self.clamp).to('cpu').numpy().astype('float32')
641 |
642 | with torch.no_grad():
643 | flow10 = torch.nn.functional.grid_sample(flow1.detach(), uv1_flow.detach(), padding_mode='zeros', align_corners=True)
644 | fb_mask = ((flow0.detach() + flow10) ** 2).sum(dim=1) < 0.5 + 0.02 * (
645 | (flow0.detach() ** 2).sum(dim=1) + (flow10 ** 2).sum(dim=1))
646 | fb_mask = fb_mask.type(torch.float32).unsqueeze(1)
647 |
648 | amb10 = torch.nn.functional.grid_sample(amb1.detach(), uv1_flow.detach(), padding_mode='zeros', align_corners=True)
649 | vc_mask = (((amb0 - amb10).abs()).mean(dim=1, keepdim=True) < 0.01).type(torch.float32)
650 |
651 | loss_mask = fb_mask * vc_mask
652 |
653 | diff = (diff * loss_mask).sum() / (loss_mask.sum() + 1e-8)
654 |
655 | return diff, loss_mask, orig_mask[0][0]
656 |
657 | def tforward(self, depth0, depth1, R0, t0, R1, t1, flow0, flow1, amb0, amb1):
658 | l0, mask0, orig_mask = self.fwd(depth0, depth1, R0, t0, R1, t1, flow0, flow1, amb0, amb1)
659 | l1, mask1, _ = self.fwd(depth1, depth0, R1, t1, R0, t0, flow1, flow0, amb1, amb0)
660 |
661 | return l0 + l1, mask0, mask1, orig_mask
662 |
663 | class LCN(TimedModule):
664 | '''
665 |     Local Contrast Normalization
666 | '''
667 | def __init__(self, radius, epsilon):
668 | super().__init__(mod_name='LCN')
669 | self.box_conv = torch.nn.Sequential(
670 | torch.nn.ReflectionPad2d(radius),
671 | torch.nn.Conv2d(1, 1, kernel_size=2*radius+1, bias=False)
672 | )
673 | self.box_conv[1].weight.requires_grad=False
674 | self.box_conv[1].weight.fill_(1.)
675 |
676 | self.epsilon = epsilon
677 | self.radius = radius
678 |
679 | def tforward(self, data):
680 | boxs = self.box_conv(data)
681 |
682 | avgs = boxs / (2*self.radius+1)**2
683 | boxs_n2 = boxs**2
684 | boxs_2n = self.box_conv(data**2)
685 |
686 | stds = torch.sqrt(torch.clamp(boxs_2n / (2*self.radius+1)**2 - avgs**2 + 1e-6, min=0))
687 | stds = stds + self.epsilon
688 |
689 | return (data - avgs) / stds, stds
690 |
691 |
692 |
693 | class SobelFilter(TimedModule):
694 | '''
695 | Sobel Filter
696 | '''
697 | def __init__(self, norm=False, ksize = 5):
698 | super(SobelFilter, self).__init__(mod_name='SobelFilter')
699 | self.ksize = ksize
700 | if self.ksize == 5:
701 | kx = np.array([[-5, -4, 0, 4, 5],
702 | [-8, -10, 0, 10, 8],
703 | [-10, -20, 0, 20, 10],
704 | [-8, -10, 0, 10, 8],
705 | [-5, -4, 0, 4, 5]])/240.0
706 | elif self.ksize == 3:
707 | kx = np.array([[-1, 0, 1],
708 | [-2, 0, 2],
709 | [-1, 0, 1]])/8.0
710 |
711 | ky = kx.copy().transpose(1,0)
712 |
713 | self.conv_x=torch.nn.Conv2d(1, 1, kernel_size=self.ksize, stride=1, padding=0, bias=False)
714 | self.conv_x.weight=torch.nn.Parameter(torch.from_numpy(kx).float().unsqueeze(0).unsqueeze(0))
715 |
716 | self.conv_y=torch.nn.Conv2d(1, 1, kernel_size=self.ksize, stride=1, padding=0, bias=False)
717 | self.conv_y.weight=torch.nn.Parameter(torch.from_numpy(ky).float().unsqueeze(0).unsqueeze(0))
718 |
719 | self.norm=norm
720 |
721 | def tforward(self,x):
722 | if self.ksize == 5:
723 | x = F.pad(x, (2,2,2,2), "replicate")
724 | elif self.ksize == 3:
725 | x = F.pad(x, (1,1,1,1), "replicate")
726 | gx = self.conv_x(x)
727 | gy = self.conv_y(x)
728 | if self.norm:
729 | return torch.sqrt(gx**2 + gy**2 + 1e-8)
730 | else:
731 | return torch.cat((gx, gy), dim=1)
--------------------------------------------------------------------------------
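
For orientation, the sketch below illustrates the two operations that DispToDepth and RectifiedPatternSimilarityLoss in networks.py build on: inverting disparity via depth = baseline * focal_length / disparity, and shifting each pixel's u-coordinate by its disparity to resample the projector pattern with grid_sample (which expects coordinates normalized to [-1, 1]). It is a standalone toy example with illustrative values only, not code from this repository.

import torch

# illustrative camera constants: baseline in meters, focal length in pixels
baseline, focal_length = 0.075, 560.0
disp = torch.full((1, 1, 4, 6), 20.0)                          # toy 4x6 disparity map
depth = baseline * focal_length / (torch.relu(disp) + 1e-12)   # about 2.1 m everywhere

# normalized sampling grid for grid_sample, as in RectifiedPatternSimilarityLoss
h, w = disp.shape[-2:]
v, u = torch.meshgrid(torch.arange(h, dtype=torch.float32),
                      torch.arange(w, dtype=torch.float32), indexing='ij')
u1 = u - disp[0, 0]                                            # shift along the epipolar line
grid = torch.stack((2 * (u1 / (w - 1) - 0.5),
                    2 * (v / (h - 1) - 0.5)), dim=-1).unsqueeze(0)
pattern = torch.rand(1, 1, h, w)                               # stand-in for the projector pattern
pattern_proj = torch.nn.functional.grid_sample(pattern, grid,
                                               padding_mode='border', align_corners=True)
print(depth[0, 0, 0, 0].item(), pattern_proj.shape)
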
/model/single_frame_worker.py:
--------------------------------------------------------------------------------
1 | # DepthInSpace is a PyTorch-based program which estimates 3D depth maps
2 | # from multiple video frames of an active structured-light sensor.
3 | #
4 | # MIT License
5 | #
6 | # Copyright, 2021 ams International AG
7 | #
8 | # Permission is hereby granted, free of charge, to any person obtaining a copy
9 | # of this software and associated documentation files (the "Software"), to deal
10 | # in the Software without restriction, including without limitation the rights
11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | # copies of the Software, and to permit persons to whom the Software is
13 | # furnished to do so, subject to the following conditions:
14 | #
15 | # The above copyright notice and this permission notice shall be included in all
16 | # copies or substantial portions of the Software.
17 | #
18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24 | # SOFTWARE.
25 |
26 | import torch
27 | import numpy as np
28 | import logging
29 | import itertools
30 | import matplotlib.pyplot as plt
31 | import co
32 |
33 | from data import base_dataset
34 | from data import dataset
35 |
36 | from model import networks
37 |
38 | from . import worker
39 |
40 | class Worker(worker.Worker):
41 | def __init__(self, args, **kwargs):
42 | super().__init__(args, **kwargs)
43 |
44 | self.disparity_loss = networks.DisparitySmoothLoss()
45 |
46 | def get_train_set(self):
47 | train_set = dataset.TrackSynDataset(self.settings_path, self.train_paths, train=True, data_aug=True, track_length=self.track_length, load_flow_data = True, load_primary_data = False, load_pseudo_gt = self.use_pseudo_gt, data_type = self.data_type)
48 | return train_set
49 |
50 | def get_test_sets(self):
51 | test_sets = base_dataset.TestSets()
52 | test_set = dataset.TrackSynDataset(self.settings_path, self.test_paths, train=False, data_aug=False, track_length=self.track_length, load_flow_data = True, load_primary_data = False, load_pseudo_gt = self.use_pseudo_gt, data_type = self.data_type)
53 | test_sets.append('simple', test_set, test_frequency=1)
54 |
55 | self.patterns = []
56 | self.ph_losses = []
57 | self.ge_losses = []
58 | self.d2ds = []
59 |
60 | self.lcn_in = self.lcn_in.to('cuda')
61 | for sidx in range(len(test_set.imsizes)):
62 | imsize = test_set.imsizes[sidx]
63 | pat = test_set.patterns[sidx]
64 | pat = pat.mean(axis=2)
65 | pat = torch.from_numpy(pat[None][None].astype(np.float32)).to('cuda')
66 | pat,_ = self.lcn_in(pat)
67 |
68 | self.patterns.append(pat)
69 |
70 | pat = torch.cat([pat for idx in range(3)], dim=1)
71 | ph_loss = networks.RectifiedPatternSimilarityLoss(imsize[0],imsize[1], pattern=pat)
72 |
73 | K = test_set.getK(sidx)
74 | Ki = np.linalg.inv(K)
75 | K = torch.from_numpy(K)
76 | Ki = torch.from_numpy(Ki)
77 | ge_loss = networks.Single_Frame_Flow_Consistency_Loss(K, Ki, imsize[0], imsize[1], clamp=0.1)
78 |
79 | self.ph_losses.append( ph_loss )
80 | self.ge_losses.append( ge_loss )
81 |
82 | d2d = networks.DispToDepth(float(test_set.focal_lengths[sidx]), float(test_set.baseline))
83 | self.d2ds.append( d2d )
84 |
85 | return test_sets
86 |
87 | def net_forward(self, net, flow = None):
88 | im0 = self.data['im0']
89 | tl = im0.shape[0]
90 | bs = im0.shape[1]
91 | im0 = im0.view(-1, *im0.shape[2:])
92 | out = net(im0)
93 |
94 | if not(isinstance(out, tuple) or isinstance(out, list)):
95 | out = out.view(tl, bs, *out.shape[1:])
96 | else:
97 | out = [o.view(tl, bs, *o.shape[1:]) for o in out]
98 |
99 | return out
100 |
101 | def loss_forward(self, out, train, flow_out = None):
102 | if not(isinstance(out, tuple) or isinstance(out, list)):
103 | out = [out]
104 |
105 | vals = []
106 |
107 | # apply photometric loss
108 | for s,o in zip(itertools.count(), out):
109 | im = self.data[f'im0']
110 | im = im.view(-1, *im.shape[2:])
111 | o = o.view(-1, *o.shape[2:])
112 | std = self.data[f'std0']
113 | std = std.view(-1, *std.shape[2:])
114 | val, pattern_proj = self.ph_losses[0](o, im[:,0:1,...], std)
115 | vals.append(val / (2 ** s))
116 |
117 | # apply disparity loss
118 | for s, o in zip(itertools.count(), out):
119 | if s == 0:
120 | amb0 = self.data[f'ambient0']
121 | amb0 = amb0.contiguous().view(-1, *amb0.shape[2:])
122 | o = o.view(-1, *o.shape[2:])
123 | val = self.disparity_loss(o, amb0)
124 | vals.append(val * 0.4 / (2 ** s))
125 |
126 | # apply geometric loss
127 | R = self.data['R']
128 | t = self.data['t']
129 | amb = self.data['ambient0']
130 | ge_num = self.track_length * (self.track_length-1) / 2
131 | for sidx in range(1):
132 | d2d = self.d2ds[0]
133 | depth = d2d(out[sidx])
134 | ge_loss = self.ge_losses[0]
135 | for tidx0 in range(depth.shape[0]):
136 | for tidx1 in range(tidx0+1, depth.shape[0]):
137 | depth0 = depth[tidx0]
138 | R0 = R[tidx0]
139 | t0 = t[tidx0]
140 | amb0 = amb[tidx0]
141 | flow0 = flow_out[f'flow_{tidx0}{tidx1}']
142 | depth1 = depth[tidx1]
143 | R1 = R[tidx1]
144 | t1 = t[tidx1]
145 | amb1 = amb[tidx1]
146 | flow1 = flow_out[f'flow_{tidx1}{tidx0}']
147 |
148 | val, flow_mask0, flow_mask1, orig_mask = ge_loss(depth0, depth1, R0, t0, R1, t1, flow0, flow1, amb0, amb1)
149 | vals.append(val * 0.2 / ge_num / (2 ** sidx))
150 |
151 | # using pseudo-ground truth
152 | if self.use_pseudo_gt:
153 | for s, o in zip(itertools.count(), out):
154 | val = torch.mean(torch.abs(o - self.data['pseudo_gt']))
155 | vals.append(val * 0.1 / (2 ** s))
156 |
157 | # warming up the network for a few epochs
158 | if train and self.data_type == 'real':
159 | if self.current_epoch < self.warmup_epochs:
160 | for s, o in zip(itertools.count(), out):
161 | valid_mask = (self.data['sgm_disp'] > 30).float()
162 | val = torch.sum(torch.abs(o - self.data['sgm_disp'] + 1.5 * torch.randn(o.size()).cuda()) * valid_mask) / torch.sum(valid_mask)
163 | vals.append(val * 0.1)
164 |
165 | return vals
166 |
167 | def numpy_in_out(self, output):
168 | if not(isinstance(output, tuple) or isinstance(output, list)):
169 | output = [output]
170 | es = output[0].detach().to('cpu').numpy()
171 | gt = self.data['disp0'].detach().to('cpu').numpy().astype(np.float32)
172 | im = self.data['im0'][:,:,0:1,...].detach().to('cpu').numpy()
173 | amb = self.data['ambient0'].detach().to('cpu').numpy()
174 | pat = self.patterns[0].detach().to('cpu').numpy()
175 |
176 | es = es * (gt > 0)
177 |
178 | return es, gt, im, amb, pat
179 |
180 | def write_img(self, out_path, es, gt, im, amb, pat):
181 | logging.info(f'write img {out_path}')
182 |
183 | diff = np.abs(es - gt)
184 |
185 | vmin, vmax = np.nanmin(gt), np.nanmax(gt)
186 | vmin = vmin - 0.2*(vmax-vmin)
187 | vmax = vmax + 0.2*(vmax-vmin)
188 |
189 | vmax = np.max([vmax, 16])
190 |
191 | fig = plt.figure(figsize=(16,16))
192 | es0 = co.cmap.color_depth_map(es[0], scale=vmax)
193 | gt0 = co.cmap.color_depth_map(gt[0], scale=vmax)
194 | diff0 = co.cmap.color_error_image(diff[0], BGR=True)
195 |
196 | # plot pattern and input images
197 | ax = plt.subplot(3,3,1); plt.imshow(pat, vmin=pat.min(), vmax=pat.max(), cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'Projector Pattern')
198 | ax = plt.subplot(3,3,2); plt.imshow(im[0], vmin=im.min(), vmax=im.max(), cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 IR Input')
199 | ax = plt.subplot(3,3,3); plt.imshow(amb[0], vmin=amb.min(), vmax=amb.max(), cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 Ambient Input')
200 |
201 | # plot disparities, ground truth disparity is shown only for reference
202 | ax = plt.subplot(3,3,4); plt.imshow(gt0[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 Disparity GT {np.nanmin(gt[0]):.4f}/{np.nanmax(gt[0]):.4f}')
203 | ax = plt.subplot(3,3,5); plt.imshow(es0[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 Disparity Est. {es[0].min():.4f}/{es[0].max():.4f}')
204 | ax = plt.subplot(3,3,6); plt.imshow(diff0[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 Disparity Err. {diff[0].mean():.5f}')
205 |
206 | es1 = co.cmap.color_depth_map(es[1], scale=vmax)
207 | gt1 = co.cmap.color_depth_map(gt[1], scale=vmax)
208 | diff1 = co.cmap.color_error_image(diff[1], BGR=True)
209 | ax = plt.subplot(3,3,7); plt.imshow(gt1[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F1 Disparity GT {np.nanmin(gt[1]):.4f}/{np.nanmax(gt[1]):.4f}')
210 | ax = plt.subplot(3,3,8); plt.imshow(es1[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F1 Disparity Est. {es[1].min():.4f}/{es[1].max():.4f}')
211 | ax = plt.subplot(3,3,9); plt.imshow(diff1[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F1 Disparity Err. {diff[1].mean():.5f}')
212 |
213 | plt.tight_layout()
214 | plt.savefig(str(out_path))
215 | plt.close(fig)
216 |
217 | def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks):
218 | if batch_idx % 256 == 0:
219 | out_path = self.exp_output_dir / f'train_{epoch:03d}_{batch_idx:04d}.png'
220 | es, gt, im, amb, pat = self.numpy_in_out(output)
221 | self.write_img(out_path, es[:, 0, 0], gt[:, 0, 0], im[:, 0, 0], amb[:, 0, 0], pat[0, 0])
222 | torch.cuda.empty_cache()
223 |
224 | def callback_test_start(self, epoch, set_idx):
225 | self.metric = co.metric.MultipleMetric(
226 | co.metric.DistanceMetric(vec_length=1),
227 | co.metric.OutlierFractionMetric(vec_length=1, thresholds=[0.1, 0.5, 1, 2, 5])
228 | )
229 |
230 | def callback_test_add(self, epoch, set_idx, batch_idx, n_batches, output, masks):
231 | es, gt, im, amb, pat = self.numpy_in_out(output)
232 |
233 | if batch_idx % 8 == 0:
234 | out_path = self.exp_output_dir / f'test_{epoch:03d}_{batch_idx:04d}.png'
235 | self.write_img(out_path, es[:, 0, 0], gt[:, 0, 0], im[:, 0, 0], amb[:, 0, 0], pat[0, 0])
236 |
237 | es = self.crop_reshape(es)
238 | gt = self.crop_reshape(gt)
239 | self.metric.add(es, gt)
240 |
241 | def crop_reshape(self, input):
242 | output = input.reshape(-1, 1)
243 | return output
244 |
245 | def callback_test_stop(self, epoch, set_idx, loss):
246 | logging.info(f'{self.metric}')
247 | for k, v in self.metric.items():
248 | self.metric_add_test(epoch, set_idx, k, v)
249 |
250 | if __name__ == '__main__':
251 | pass
252 |
--------------------------------------------------------------------------------
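
For reference, the objective assembled in loss_forward above combines, per output scale s: the photometric term weighted 1/2^s, a disparity smoothness term weighted 0.4 at the finest scale only, the flow-consistency term weighted 0.2 / ge_num per frame pair at the finest scale only, an optional L1 term against the pseudo ground truth weighted 0.1/2^s, and, for real data during the warm-up epochs, a 0.1-weighted masked L1 term against the SGM disparity. The toy function below only mirrors that weighting with scalar stand-ins; it is not part of the worker.

import torch

def toy_total_loss(photo, smooth, geo_pairs, pseudo_l1=None, sgm_l1=None,
                   track_length=4, s=0):
    # mirrors the weights used in loss_forward above, for the finest scale s = 0;
    # all inputs are scalar stand-ins, not the worker's real tensors
    ge_num = track_length * (track_length - 1) / 2
    total = photo / (2 ** s)
    total = total + 0.4 * smooth                       # disparity smoothness (s == 0 only)
    total = total + 0.2 * sum(geo_pairs) / ge_num      # flow consistency over frame pairs
    if pseudo_l1 is not None:
        total = total + 0.1 * pseudo_l1 / (2 ** s)     # pseudo ground-truth supervision
    if sgm_l1 is not None:
        total = total + 0.1 * sgm_l1                   # SGM warm-up term (real data only)
    return total

print(toy_total_loss(torch.tensor(0.8), torch.tensor(0.05),
                     [torch.tensor(0.3)] * 6))         # 6 pairs for track_length = 4
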
/model/worker.py:
--------------------------------------------------------------------------------
1 | # DepthInSpace is a PyTorch-based program which estimates 3D depth maps
2 | # from multiple video frames of an active structured-light sensor.
3 | #
4 | #
5 | # MIT License
6 | #
7 | # Copyright (c) 2019 autonomousvision
8 | #
9 | # Permission is hereby granted, free of charge, to any person obtaining a copy
10 | # of this software and associated documentation files (the "Software"), to deal
11 | # in the Software without restriction, including without limitation the rights
12 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13 | # copies of the Software, and to permit persons to whom the Software is
14 | # furnished to do so, subject to the following conditions:
15 | #
16 | # The above copyright notice and this permission notice shall be included in all
17 | # copies or substantial portions of the Software.
18 | #
19 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
20 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
22 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
23 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
24 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
25 | # SOFTWARE.
26 | #
27 | #
28 | # MIT License
29 | #
30 | # Copyright, 2021 ams International AG
31 | #
32 | # Permission is hereby granted, free of charge, to any person obtaining a copy
33 | # of this software and associated documentation files (the "Software"), to deal
34 | # in the Software without restriction, including without limitation the rights
35 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
36 | # copies of the Software, and to permit persons to whom the Software is
37 | # furnished to do so, subject to the following conditions:
38 | #
39 | # The above copyright notice and this permission notice shall be included in all
40 | # copies or substantial portions of the Software.
41 | #
42 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
43 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
44 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
45 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
46 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
47 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
48 | # SOFTWARE.
49 |
50 | import pickle
51 |
52 | import numpy as np
53 | import torch
54 | import os
55 | import random
56 | import logging
57 | import datetime
58 | from pathlib import Path
59 | import argparse
60 | import socket
61 | import gc
62 | import json
63 | import matplotlib.pyplot as plt
64 | import time
65 | from collections import OrderedDict
66 | from model import networks
67 |
68 |
69 | class StopWatch(object):
70 | def __init__(self):
71 | self.timings = OrderedDict()
72 | self.starts = {}
73 |
74 | def start(self, name):
75 | self.starts[name] = time.time()
76 |
77 | def stop(self, name):
78 | if name not in self.timings:
79 | self.timings[name] = []
80 | self.timings[name].append(time.time() - self.starts[name])
81 |
82 | def get(self, name=None, reduce=np.sum):
83 | if name is not None:
84 | return reduce(self.timings[name])
85 | else:
86 | ret = {}
87 | for k in self.timings:
88 | ret[k] = reduce(self.timings[k])
89 | return ret
90 |
91 | def __repr__(self):
92 | return ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get().items()])
93 | def __str__(self):
94 | return ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get().items()])
95 |
96 |
97 | class ETA(object):
98 | def __init__(self, length):
99 | self.length = length
100 | self.start_time = time.time()
101 | self.current_idx = 0
102 | self.current_time = time.time()
103 |
104 | def update(self, idx):
105 | self.current_idx = idx
106 | self.current_time = time.time()
107 |
108 | def get_elapsed_time(self):
109 | return self.current_time - self.start_time
110 |
111 | def get_item_time(self):
112 | return self.get_elapsed_time() / (self.current_idx + 1)
113 |
114 | def get_remaining_time(self):
115 | return self.get_item_time() * (self.length - self.current_idx + 1)
116 |
117 | def format_time(self, seconds):
118 | minutes, seconds = divmod(seconds, 60)
119 | hours, minutes = divmod(minutes, 60)
120 | hours = int(hours)
121 | minutes = int(minutes)
122 | return f'{hours:02d}:{minutes:02d}:{seconds:05.2f}'
123 |
124 | def get_elapsed_time_str(self):
125 | return self.format_time(self.get_elapsed_time())
126 |
127 | def get_remaining_time_str(self):
128 | return self.format_time(self.get_remaining_time())
129 |
130 | class Worker(object):
131 | def __init__(self, args, seed=42, test_batch_size=4, num_workers=4, save_frequency=1, train_device='cuda:0', test_device='cuda:0', max_train_iter=-1):
132 | self.use_pseudo_gt = args.use_pseudo_gt
133 | self.lcn_radius = args.lcn_radius
134 | self.track_length = args.track_length
135 | self.data_type = args.data_type
136 | # assert(self.track_length>1)
137 |
138 | self.architecture = args.architecture
139 | self.epochs = args.epochs
140 | self.warmup_epochs = args.warmup_epochs
141 | self.seed = seed
142 | self.train_batch_size = args.train_batch_size
143 | self.test_batch_size = test_batch_size
144 | self.num_workers = num_workers
145 | self.save_frequency = save_frequency
146 | self.train_device = train_device
147 | self.test_device = test_device
148 | self.max_train_iter = max_train_iter
149 |
150 | self.errs_list=[]
151 |
152 | config_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'config.json'))
153 | with open(config_path) as fp:
154 | config = json.load(fp)
155 | data_root = Path(config['DATA_DIR'])
156 | output_dir = Path(config['OUTPUT_DIR'])
157 | self.settings_path = data_root / 'settings.pkl'
158 | self.output_dir = output_dir
159 | with open(str(self.settings_path), 'rb') as f:
160 | settings = pickle.load(f)
161 | self.baseline = settings['baseline']
162 | self.K = settings['K']
163 | self.Ki = np.linalg.inv(self.K)
164 | self.imsizes = [settings['imsize']]
165 |         for _ in range(3):
166 | self.imsizes.append((int(self.imsizes[-1][0] / 2), int(self.imsizes[-1][1] / 2)))
167 | self.ref_pattern = settings['pattern']
168 |
169 | sample_paths = sorted((data_root).glob('0*/'))
170 | if self.data_type == 'synthetic':
171 | self.train_paths = sample_paths[2**10:]
172 | self.test_paths = sample_paths[2**9:2**10]
173 | self.valid_paths = sample_paths[0:2**9]
174 | elif self.data_type == 'real':
175 | self.test_paths = sample_paths[4::8]
176 | self.train_paths = [path for path in sample_paths if path not in self.test_paths]
177 |
178 | self.lcn_in = networks.LCN(self.lcn_radius, 0.05).cuda()
179 |
180 | self.setup_experiment()
181 |
182 | def setup_experiment(self):
183 | self.exp_output_dir = self.output_dir / self.architecture
184 | self.exp_output_dir.mkdir(parents=True, exist_ok=True)
185 |
186 | if logging.root: del logging.root.handlers[:]
187 | logging.basicConfig(
188 | level=logging.INFO,
189 | handlers=[
190 | logging.FileHandler( str(self.exp_output_dir / 'train.log') ),
191 | logging.StreamHandler()
192 | ],
193 | format='%(relativeCreated)d:%(levelname)s:%(process)d-%(processName)s: %(message)s'
194 | )
195 |
196 | logging.info('='*80)
197 | logging.info(f'Start of experiment with architecture: {self.architecture}')
198 | logging.info(socket.gethostname())
199 | self.log_datetime()
200 | logging.info('='*80)
201 |
202 | self.metric_path = self.exp_output_dir / 'metrics.json'
203 | if self.metric_path.exists():
204 | with open(str(self.metric_path), 'r') as fp:
205 | self.metric_data = json.load(fp)
206 | else:
207 | self.metric_data = {}
208 |
209 | self.init_seed()
210 |
211 | def metric_add_train(self, epoch, key, val):
212 | epoch = str(epoch)
213 | key = str(key)
214 | if epoch not in self.metric_data:
215 | self.metric_data[epoch] = {}
216 | if 'train' not in self.metric_data[epoch]:
217 | self.metric_data[epoch]['train'] = {}
218 | self.metric_data[epoch]['train'][key] = val
219 |
220 | def metric_add_test(self, epoch, set_idx, key, val):
221 | epoch = str(epoch)
222 | set_idx = str(set_idx)
223 | key = str(key)
224 | if epoch not in self.metric_data:
225 | self.metric_data[epoch] = {}
226 | if 'test' not in self.metric_data[epoch]:
227 | self.metric_data[epoch]['test'] = {}
228 | if set_idx not in self.metric_data[epoch]['test']:
229 | self.metric_data[epoch]['test'][set_idx] = {}
230 | self.metric_data[epoch]['test'][set_idx][key] = val
231 |
232 | def metric_save(self):
233 | with open(str(self.metric_path), 'w') as fp:
234 | json.dump(self.metric_data, fp, indent=2)
235 |
236 | def init_seed(self, seed=None):
237 | if seed is not None:
238 | self.seed = seed
239 | logging.info(f'Set seed to {self.seed}')
240 | np.random.seed(self.seed)
241 | random.seed(self.seed)
242 | torch.manual_seed(self.seed)
243 | torch.cuda.manual_seed(self.seed)
244 |
245 | def log_datetime(self):
246 | logging.info(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
247 |
248 | def mem_report(self):
249 | for obj in gc.get_objects():
250 | if torch.is_tensor(obj):
251 | print(type(obj), obj.shape)
252 |
253 | def get_net_path(self, epoch, root=None):
254 | if root is None:
255 | root = self.exp_output_dir
256 | return root / f'net_{epoch:04d}.params'
257 |
258 | def get_do_parser_cmds(self):
259 | return ['retrain', 'resume', 'retest', 'test_init']
260 |
261 | def get_do_parser(self):
262 | parser = argparse.ArgumentParser()
263 | parser.add_argument('--cmd', type=str, default='resume', choices=self.get_do_parser_cmds())
264 | parser.add_argument('--epoch', type=int, default=-1)
265 | return parser
266 |
267 | def do_cmd(self, args, net, optimizer, scheduler=None):
268 | if args.cmd == 'retrain':
269 | self.train(net, optimizer, resume=False, scheduler=scheduler)
270 | elif args.cmd == 'resume':
271 | self.train(net, optimizer, resume=True, scheduler=scheduler)
272 | elif args.cmd == 'retest':
273 | self.retest(net, epoch=args.epoch)
274 | elif args.cmd == 'test_init':
275 | test_sets = self.get_test_sets()
276 | self.test(-1, net, test_sets)
277 | else:
278 | raise Exception('invalid cmd')
279 |
280 | def do(self, net, optimizer, load_net_optimizer=None, scheduler=None):
281 | parser = self.get_do_parser()
282 | args, _ = parser.parse_known_args()
283 |
284 | if load_net_optimizer is not None and args.cmd not in ['schedule']:
285 | net, optimizer = load_net_optimizer()
286 |
287 | self.do_cmd(args, net, optimizer, scheduler=scheduler)
288 |
289 | def retest(self, net, epoch=-1):
290 | if epoch < 0:
291 | epochs = range(self.epochs)
292 | else:
293 | epochs = [epoch]
294 |
295 | test_sets = self.get_test_sets()
296 |
297 | for epoch in epochs:
298 | net_path = self.get_net_path(epoch)
299 | if net_path.exists():
300 | state_dict = torch.load(str(net_path))
301 | net.load_state_dict(state_dict)
302 | self.test(epoch, net, test_sets)
303 |
304 | def format_err_str(self, errs, div=1):
305 | err = sum(errs)
306 | if len(errs) > 1:
307 | err_str = f'{err/div:0.4f}=' + '+'.join([f'{e/div:0.4f}' for e in errs])
308 | else:
309 | err_str = f'{err/div:0.4f}'
310 | return err_str
311 |
312 | def write_err_img(self):
313 | err_img_path = self.exp_output_dir / 'errs.png'
314 | fig = plt.figure(figsize=(16,16))
315 | lines=[]
316 | for idx,errs in enumerate(self.errs_list):
317 | line,=plt.plot(range(len(errs)), errs, label=f'error{idx}')
318 | lines.append(line)
319 | plt.tight_layout()
320 | plt.legend(handles=lines)
321 | plt.savefig(str(err_img_path))
322 | plt.close(fig)
323 |
324 |
325 | def callback_train_new_epoch(self, epoch, net, optimizer):
326 | pass
327 |
328 | def train(self, net, optimizer, resume=False, scheduler=None):
329 | logging.info('='*80)
330 | logging.info('Start training')
331 | self.log_datetime()
332 | logging.info('='*80)
333 |
334 | train_set = self.get_train_set()
335 | test_sets = self.get_test_sets()
336 |
337 | net = net.to(self.train_device)
338 |
339 | epoch = 0
340 | min_err = {ts.name: 1e9 for ts in test_sets}
341 |
342 | state_path = self.exp_output_dir / 'state.dict'
343 | if resume and state_path.exists():
344 | logging.info('='*80)
345 | logging.info(f'Loading state from {state_path}')
346 | logging.info('='*80)
347 | state = torch.load(str(state_path))
348 | epoch = state['epoch'] + 1
349 | if 'min_err' in state:
350 | min_err = state['min_err']
351 |
352 | curr_state = net.state_dict()
353 | curr_state.update(state['state_dict'])
354 | net.load_state_dict(curr_state)
355 |
356 | try:
357 | optimizer.load_state_dict(state['optimizer'])
358 |             except Exception:
359 | logging.info('Warning: cannot load optimizer from state_dict')
360 | pass
361 | if 'cpu_rng_state' in state:
362 | torch.set_rng_state(state['cpu_rng_state'])
363 | if 'gpu_rng_state' in state:
364 | torch.cuda.set_rng_state(state['gpu_rng_state'])
365 |
366 | for epoch in range(epoch, self.epochs):
367 | self.current_epoch = epoch
368 | self.callback_train_new_epoch(epoch, net, optimizer)
369 |
370 | # train epoch
371 | self.train_epoch(epoch, net, optimizer, train_set)
372 |
373 | # test epoch
374 | errs = self.test(epoch, net, test_sets)
375 |
376 | if (epoch + 1) % self.save_frequency == 0:
377 | net = net.to(self.train_device)
378 |
379 | state_dict = {
380 | 'epoch': epoch,
381 | 'min_err': min_err,
382 | 'state_dict': net.state_dict(),
383 | 'optimizer': optimizer.state_dict(),
384 | 'cpu_rng_state': torch.get_rng_state(),
385 | 'gpu_rng_state': torch.cuda.get_rng_state(),
386 | }
387 |                 state_path = self.exp_output_dir / 'state.dict'
388 |                 logging.info(f'save state to {state_path}')
389 | torch.save(state_dict, str(state_path))
390 |
391 | for test_set_name in errs:
392 | err = sum(errs[test_set_name])
393 | if err < min_err[test_set_name]:
394 | min_err[test_set_name] = err
395 | state_path = self.exp_output_dir / f'state_set_{test_set_name}_best.dict'
396 | logging.info(f'save state to {state_path}')
397 | torch.save(state_dict, str(state_path))
398 |
399 | # store network
400 | net_path = self.get_net_path(epoch)
401 | logging.info(f'save network to {net_path}')
402 | torch.save(net.state_dict(), str(net_path))
403 |
404 | if scheduler is not None:
405 | scheduler.step()
406 |
407 | logging.info('='*80)
408 | logging.info('Finished training')
409 | self.log_datetime()
410 | logging.info('='*80)
411 |
412 | def get_train_set(self):
413 | raise NotImplementedError()
414 |
415 | def get_test_sets(self):
416 | raise NotImplementedError()
417 |
418 | def copy_data(self, data, device, requires_grad, train):
419 | self.data = {}
420 |
421 | self.lcn_in = self.lcn_in.to(device)
422 | for key, val in data.items():
423 | # from
424 | # batch_size x track_length x ...
425 | # to
426 | # track_length x batch_size x ...
427 | if len(val.shape)>2:
428 | val = val.transpose(0, 1)
429 | self.data[key] = val.to(device)
430 | if 'im' in key and 'blend' not in key and 'primary' not in key:
431 | im = self.data[key]
432 | tl = im.shape[0]
433 | bs = im.shape[1]
434 | im_lcn,im_std = self.lcn_in(im.contiguous().view(-1, *im.shape[2:]))
435 | key_std = key.replace('im','std')
436 | self.data[key_std] = im_std.view(tl, bs, *im.shape[2:]).to(device)
437 | im_cat = torch.cat((im_lcn.view(tl, bs, *im.shape[2:]), im), dim=2)
438 | self.data[key] = im_cat
439 | elif key == 'ambient0':
440 | ambient = self.data[key]
441 | tl = ambient.shape[0]
442 | bs = ambient.shape[1]
443 | ambient_lcn, ambient_std = self.lcn_in(ambient.contiguous().view(-1, *ambient.shape[2:]))
444 | ambient_cat = torch.cat((ambient_lcn.view(tl, bs, *ambient.shape[2:]), ambient), dim=2)
445 | self.data[f'{key}_in'] = ambient_cat.to(device).requires_grad_(requires_grad=requires_grad)
446 |
447 |         # Mimicking the reference pattern
448 | pat = self.ref_pattern.mean(axis=2)
449 | pat = torch.from_numpy(pat[None][None].astype(np.float32)).to('cuda')
450 | pat_lcn, _ = self.lcn_in(pat)
451 | pat_cat = torch.cat((pat_lcn, pat), dim=1).unsqueeze(0)
452 | self.data[f'ref_pattern'] = pat_cat.repeat([*self.data['im0'].shape[0:2], 1, 1, 1])
453 |
454 | def net_forward(self, net, train):
455 | raise NotImplementedError()
456 |
457 | def read_optical_flow(self, train):
458 | im = self.data['ambient0']
459 | out = {}
460 | for tidx0 in range(im.shape[0]):
461 | for tidx1 in range(im.shape[0]):
462 | if tidx0 != tidx1:
463 | out[f'flow_{tidx0}{tidx1}'] = self.data[f'flow_{tidx0}{tidx1}'][0]
464 |
465 | return out
466 |
467 | def loss_forward(self, output, train, flow_out):
468 | raise NotImplementedError()
469 |
470 | def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks):
471 | pass
472 |
473 | def callback_train_start(self, epoch):
474 | pass
475 |
476 | def callback_train_stop(self, epoch, loss):
477 | pass
478 |
479 | def train_epoch(self, epoch, net, optimizer, dset):
480 | self.callback_train_start(epoch)
481 | stopwatch = StopWatch()
482 |
483 | logging.info('='*80)
484 | logging.info('Train epoch %d' % epoch)
485 |
486 | dset.current_epoch = epoch
487 | train_loader = torch.utils.data.DataLoader(dset, batch_size=self.train_batch_size, shuffle=True, num_workers=self.num_workers, drop_last=True, pin_memory=False)
488 |
489 | net = net.to(self.train_device)
490 | net.train()
491 |
492 | mean_loss = None
493 |
494 | n_batches = self.max_train_iter if self.max_train_iter > 0 else len(train_loader)
495 | bar = ETA(length=n_batches)
496 |
497 | stopwatch.start('total')
498 | stopwatch.start('data')
499 | for batch_idx, data in enumerate(train_loader):
500 | if self.max_train_iter > 0 and batch_idx > self.max_train_iter: break
501 | self.copy_data(data, device=self.train_device, requires_grad=False, train=True)
502 | stopwatch.stop('data')
503 |
504 | optimizer.zero_grad()
505 |
506 | stopwatch.start('forward')
507 | flow_output = self.read_optical_flow(train=True)
508 | output = self.net_forward(net, flow_output)
509 |
510 | if 'cuda' in self.train_device: torch.cuda.synchronize()
511 | stopwatch.stop('forward')
512 |
513 | stopwatch.start('loss')
514 | errs = self.loss_forward(output, True, flow_output)
515 | if isinstance(errs, dict):
516 | masks = errs['masks']
517 | errs = errs['errs']
518 | else:
519 | masks = []
520 | if not isinstance(errs, list) and not isinstance(errs, tuple):
521 | errs = [errs]
522 | err = sum(errs)
523 | if 'cuda' in self.train_device: torch.cuda.synchronize()
524 | stopwatch.stop('loss')
525 |
526 | stopwatch.start('backward')
527 | err.backward()
528 | self.callback_train_post_backward(net, errs, output, epoch, batch_idx, masks)
529 | if 'cuda' in self.train_device: torch.cuda.synchronize()
530 | stopwatch.stop('backward')
531 |
532 | # print('Allocated:', round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1), 'GB')
533 | # print('Max Allocated:', round(torch.cuda.max_memory_allocated(0) / 1024 ** 3, 1), 'GB')
534 | # print('Max Cached:', round(torch.cuda.max_memory_cached(0) / 1024 ** 3, 1), 'GB')
535 |
536 | stopwatch.start('optimizer')
537 | optimizer.step()
538 | if 'cuda' in self.train_device: torch.cuda.synchronize()
539 | stopwatch.stop('optimizer')
540 |
541 | bar.update(batch_idx)
542 | if (epoch <= 1 and batch_idx < 128) or batch_idx % 16 == 0:
543 | err_str = self.format_err_str(errs)
544 | logging.info(f'train e{epoch}: {batch_idx+1}/{len(train_loader)}: loss={err_str} | {bar.get_elapsed_time_str()} / {bar.get_remaining_time_str()}')
545 | #self.write_err_img()
546 |
547 |
548 | if mean_loss is None:
549 | mean_loss = [0 for e in errs]
550 | for erridx, err in enumerate(errs):
551 | mean_loss[erridx] += err.item()
552 |
553 | stopwatch.start('data')
554 | stopwatch.stop('total')
555 | logging.info('timings: %s' % stopwatch)
556 |
557 | mean_loss = [l / len(train_loader) for l in mean_loss]
558 | self.callback_train_stop(epoch, mean_loss)
559 | self.metric_add_train(epoch, 'loss', mean_loss)
560 |
561 | # save metrics
562 | self.metric_save()
563 |
564 | err_str = self.format_err_str(mean_loss)
565 | logging.info(f'avg train_loss={err_str}')
566 | return mean_loss
567 |
568 | def callback_test_start(self, epoch, set_idx):
569 | pass
570 |
571 | def callback_test_add(self, epoch, set_idx, batch_idx, n_batches, output, masks):
572 | pass
573 |
574 | def callback_test_stop(self, epoch, set_idx, loss):
575 | pass
576 |
577 | def test(self, epoch, net, test_sets):
578 | errs = {}
579 | for test_set_idx, test_set in enumerate(test_sets):
580 | if (epoch + 1) % test_set.test_frequency == 0:
581 | logging.info('='*80)
582 | logging.info(f'testing set {test_set.name}')
583 | err = self.test_epoch(epoch, test_set_idx, net, test_set.dset)
584 | errs[test_set.name] = err
585 | return errs
586 |
587 | def test_epoch(self, epoch, set_idx, net, dset):
588 | logging.info('-'*80)
589 | logging.info('Test epoch %d' % epoch)
590 | dset.current_epoch = epoch
591 | test_loader = torch.utils.data.DataLoader(dset, batch_size=self.test_batch_size, shuffle=False, num_workers=self.num_workers, drop_last=False, pin_memory=False)
592 |
593 | net = net.to(self.test_device)
594 | net.eval()
595 |
596 | with torch.no_grad():
597 | mean_loss = None
598 |
599 | self.callback_test_start(epoch, set_idx)
600 |
601 | bar = ETA(length=len(test_loader))
602 | stopwatch = StopWatch()
603 | stopwatch.start('total')
604 | stopwatch.start('data')
605 | for batch_idx, data in enumerate(test_loader):
606 | self.copy_data(data, device=self.test_device, requires_grad=False, train=False)
607 | stopwatch.stop('data')
608 |
609 | stopwatch.start('forward')
610 | flow_output = self.read_optical_flow(train=False)
611 |
612 | output = self.net_forward(net, flow_output)
613 |
614 | if 'cuda' in self.test_device: torch.cuda.synchronize()
615 | stopwatch.stop('forward')
616 |
617 | stopwatch.start('loss')
618 | errs = self.loss_forward(output, False, flow_output)
619 | if isinstance(errs, dict):
620 | masks = errs['masks']
621 | errs = errs['errs']
622 | else:
623 | masks = []
624 | if not isinstance(errs, list) and not isinstance(errs, tuple):
625 | errs = [errs]
626 |
627 | bar.update(batch_idx)
628 | if batch_idx % 25 == 0:
629 | err_str = self.format_err_str(errs)
630 | logging.info(f'test e{epoch}: {batch_idx+1}/{len(test_loader)}: loss={err_str} | {bar.get_elapsed_time_str()} / {bar.get_remaining_time_str()}')
631 |
632 | if mean_loss is None:
633 | mean_loss = [0 for e in errs]
634 | for erridx, err in enumerate(errs):
635 | mean_loss[erridx] += err.item()
636 | stopwatch.stop('loss')
637 |
638 | self.callback_test_add(epoch, set_idx, batch_idx, len(test_loader), output, masks)
639 |
640 | stopwatch.start('data')
641 | stopwatch.stop('total')
642 | logging.info('timings: %s' % stopwatch)
643 |
644 | mean_loss = [l / len(test_loader) for l in mean_loss]
645 | self.callback_test_stop(epoch, set_idx, mean_loss)
646 | self.metric_add_test(epoch, set_idx, 'loss', mean_loss)
647 |
648 | # save metrics
649 | self.metric_save()
650 |
651 | err_str = self.format_err_str(mean_loss)
652 | logging.info(f'test epoch {epoch}: avg test_loss={err_str}')
653 | return mean_loss
654 |
--------------------------------------------------------------------------------
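
The training loop above checkpoints everything needed for an exact resume into state.dict inside the experiment output directory (OUTPUT_DIR/<architecture>). Below is a minimal sketch of inspecting such a checkpoint offline; the path is an assumption for a config with OUTPUT_DIR set to ./output and the single_frame architecture, not a file shipped with the repository.

import torch

# keys written by Worker.train above: 'epoch', 'min_err', 'state_dict',
# 'optimizer', 'cpu_rng_state', 'gpu_rng_state'
state = torch.load('./output/single_frame/state.dict', map_location='cpu')
print('last completed epoch:', state['epoch'])
print('best test error per set:', state.get('min_err'))
print('number of network tensors:', len(state['state_dict']))
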
/requirements.txt:
--------------------------------------------------------------------------------
1 | pytorch
2 | torchvision
3 | cython
4 | numpy
5 | matplotlib
6 | pandas
7 | scipy
8 | opencv
9 | h5py
10 | tqdm
11 | cupy
12 | ipdb
--------------------------------------------------------------------------------
/train_val.py:
--------------------------------------------------------------------------------
1 | # DepthInSpace is a PyTorch-based program which estimates 3D depth maps
2 | # from multiple video frames of an active structured-light sensor.
3 | #
4 | # MIT License
5 | #
6 | # Copyright, 2021 ams International AG
7 | #
8 | # Permission is hereby granted, free of charge, to any person obtaining a copy
9 | # of this software and associated documentation files (the "Software"), to deal
10 | # in the Software without restriction, including without limitation the rights
11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | # copies of the Software, and to permit persons to whom the Software is
13 | # furnished to do so, subject to the following conditions:
14 | #
15 | # The above copyright notice and this permission notice shall be included in all
16 | # copies or substantial portions of the Software.
17 | #
18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24 | # SOFTWARE.
25 |
26 | import torch
27 | from model import single_frame_worker
28 | from model import multi_frame_worker
29 | from model import networks
30 | from model import multi_frame_networks
31 | from co.args import parse_args
32 |
33 | torch.backends.cudnn.benchmark = True
34 |
35 | # parse args
36 | args = parse_args()
37 |
38 | # select the worker for the chosen architecture
39 | if args.architecture == 'single_frame':
40 | worker = single_frame_worker.Worker(args)
41 | elif args.architecture == 'multi_frame':
42 | worker = multi_frame_worker.Worker(args)
43 |
44 | if args.use_pseudo_gt and args.architecture != 'single_frame':
45 | print("Using pseudo-gt is only possible in single-frame architecture")
46 | raise NotImplementedError
47 |
48 | # set up network
49 | if args.architecture == 'single_frame':
50 | net = networks.DispDecoder(channels_in=2, max_disp=args.max_disp, imsizes=worker.imsizes)
51 | elif args.architecture == 'multi_frame':
52 | net = multi_frame_networks.FuseNet(imsize=worker.imsizes[0], K=worker.K, baseline=worker.baseline, track_length=worker.track_length, max_disp=args.max_disp)
53 |
54 | # optimizer
55 | opt_parameters = net.parameters()
56 | optimizer = torch.optim.Adam(opt_parameters, lr=1e-4)
57 |
58 | # start the work
59 | worker.do(net, optimizer)
--------------------------------------------------------------------------------