├── LICENSE ├── README.md ├── co ├── __init__.py ├── args.py ├── cmap.py ├── geometry.py ├── gtimer.py ├── io3d.py ├── metric.py └── utils.py ├── config.json ├── data ├── __init__.py ├── base_dataset.py ├── create_syn_data.py ├── data_manipulation.py ├── dataset.py ├── default_pattern.png ├── kinect_pattern.png ├── presave_disp.py ├── presave_optical_flow_data.py └── real_pattern.png ├── model ├── __init__.py ├── ext_functions.py ├── multi_frame_networks.py ├── multi_frame_worker.py ├── networks.py ├── single_frame_worker.py └── worker.py ├── requirements.txt └── train_val.py /LICENSE: -------------------------------------------------------------------------------- 1 | LICENSE FOR ORIGINAL FILE / https://github.com/autonomousvision/connecting_the_dots/blob/master/LICENSE 2 | 3 | MIT License 4 | 5 | Copyright (c) 2019 autonomousvision 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in all 15 | copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | SOFTWARE. 24 | 25 | 26 | 27 | LICENSE FOR MODIFICATIONS OF ORIGINAL FILE AND FOR OTHER FILES IN THIS REPOSITORY 28 | 29 | MIT LICENSE 30 | 31 | Copyright 2021, ams International AG 32 | 33 | Permission is hereby granted, free of charge, to any person obtaining a copy 34 | of this software and associated documentation files (the "Software"), to deal 35 | in the Software without restriction, including without limitation the rights 36 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 37 | copies of the Software, and to permit persons to whom the Software is 38 | furnished to do so, subject to the following conditions: 39 | 40 | The above copyright notice and this permission notice shall be included in all 41 | copies or substantial portions of the Software. 42 | 43 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 44 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 45 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 46 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 47 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 48 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 49 | SOFTWARE. 50 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | > # [ICCV 2021] DepthInSpace: Exploitation and Fusion of Multiple Frames of a Video for Structured-Light Depth Estimation
3 | > Mohammad Mahdi Johari, Camilla Carta, François Fleuret
4 | > [Project Page](https://www.idiap.ch/paper/depthinspace/) | [Paper](https://openaccess.thecvf.com/content/ICCV2021/html/Johari_DepthInSpace_Exploitation_and_Fusion_of_Multiple_Video_Frames_for_Structured-Light_ICCV_2021_paper.html)
5 |
6 | ## Dependencies
7 |
8 | The network training/evaluation code is based on `PyTorch` and is tested in the following environment:
9 | ```
10 | Python==3.8.6
11 | PyTorch==1.7.0
12 | CUDA==10.1
13 | ```
14 |
15 | All required packages can be installed with `anaconda`:
16 | ```
17 | conda install --file requirements.txt -c pytorch -c conda-forge
18 | ```
19 |
20 | ### External Libraries
21 | To train and evaluate our method on synthetic datasets, we use the structured-light renderer provided by [Connecting the Dots](https://github.com/autonomousvision/connecting_the_dots).
22 | It can be used to render a virtual scene (an arbitrary triangle mesh) with the structured-light pattern projected from a customizable projector location.
23 | Furthermore, our models use some custom layers provided in [Connecting the Dots](https://github.com/autonomousvision/connecting_the_dots).
24 | First, download [ShapeNet V2](https://www.shapenet.org/) and set `SHAPENET_DIR` in `config.json` accordingly.
25 | Then, to install these dependencies, use the following instructions and set `CTD_DIR` in `config.json` to the path of the cloned [Connecting the Dots](https://github.com/autonomousvision/connecting_the_dots) repository:
26 |
27 | ```
28 | git clone https://github.com/autonomousvision/connecting_the_dots.git
29 | cd connecting_the_dots
30 | cd renderer
31 | make
32 | cd ..
33 | cd data/lcn
34 | python setup.py build_ext --inplace
35 | cd ../..
36 | cd torchext
37 | python setup.py build_ext --inplace
38 | cd ..
39 | ```
40 |
41 | As a preprocessing step, you need to run [LiteFlowNet](https://github.com/sniklaus/pytorch-liteflownet) on the data before running our models. To this end, clone [our forked copy of pytorch-liteflownet](https://github.com/MohammadJohari/pytorch-liteflownet) with the following command and set `LITEFLOWNET_DIR` in `config.json` accordingly.
42 | ```
43 | git clone https://github.com/MohammadJohari/pytorch-liteflownet.git
44 | ```
45 | Make sure you comply with the [license](https://github.com/twhui/LiteFlowNet#license-and-citation) terms of LiteFlowNet's original paper.
46 | However, DepthInSpace models are compatible with any external optical flow library.
47 | To use DepthInSpace with another optical flow library, you need to modify the code in `data/presave_optical_flow_data.py` so that it is compatible with your optical flow model of choice; a rough sketch of such a modification is given below.
48 |
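In practice, the change amounts to swapping the call that produces a flow field for each pair of neighbouring frames. The sketch below is not the repository's actual implementation; the function names, sample layout, and output file names (`estimate_flow`, `presave_flow_for_sample`, `flow.npy`) are placeholders that you would have to adapt to `data/presave_optical_flow_data.py` and to your flow model's API:

```python
# Hypothetical sketch only: plugging a generic optical flow model into the
# pre-saving step. All names below (estimate_flow, the frame list, the
# output file layout) are placeholders, not the actual DepthInSpace interface.
import numpy as np
import torch

def estimate_flow(model, im0, im1):
    # im0, im1: HxW numpy arrays of two consecutive IR frames.
    # Returns a 2xHxW numpy array with the flow from im0 to im1.
    t0 = torch.from_numpy(im0).float().view(1, 1, *im0.shape[-2:])
    t1 = torch.from_numpy(im1).float().view(1, 1, *im1.shape[-2:])
    with torch.no_grad():
        flow = model(t0, t1)          # replace with your model's actual call
    return flow[0].cpu().numpy()

def presave_flow_for_sample(model, frames, out_path):
    # frames: list of consecutive IR images belonging to one video sample.
    flows = [estimate_flow(model, frames[i], frames[i + 1])
             for i in range(len(frames) - 1)]
    np.save(out_path, np.stack(flows))  # e.g. a flow.npy next to the sample
```

Whatever model you plug in, the saved arrays must keep the layout that the training code later loads, so treat the file names and shapes above as illustrative only.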
49 | ## Running
50 |
51 |
52 | ### Creating Synthetic Data
53 | The synthetic data will be generated and saved to `DATA_DIR` in `config.json`.
54 | In order to generate the data with the `default` projection dot pattern, change directory to `data` and run
55 |
56 | ```
57 | python create_syn_data.py default
58 | ```
59 |
60 | Other available options for the dot pattern are `kinect` and `real`, where `real` is the dot pattern observed in our real-world experiments.
61 |
62 | ### Pre-Saving Optical Flow Data
63 | Before training, optical flow predictions from LiteFlowNet should be pre-saved. To do so, make sure that `DATA_DIR` in `config.json` is correct and run the following command in the `data` directory:
64 |
65 | ```
66 | python presave_optical_flow_data.py
67 | ```
68 |
69 | ### Training DIS-SF
70 | Note that the weights and training state of our networks are saved in `OUTPUT_DIR` in `config.json`.
71 | To train the DIS-SF network with an arbitrary batch size (e.g. 8), you can run
72 |
73 | ```
74 | python train_val.py --train_batch_size 8 --architecture single_frame
75 | ```
76 |
77 | ### Training DIS-MF
78 | Before training the DIS-MF network, the DIS-SF network must have been trained and its outputs must have been pre-saved.
79 | Make sure that `DATA_DIR` in `config.json` is correct and that the trained weights of the DIS-SF network are available in `OUTPUT_DIR` in `config.json`.
80 | Then, you can pre-save the outputs of a specific epoch (e.g. 100) of the DIS-SF network by running the following command in the `data` directory:
81 |
82 | ```
83 | python presave_disp.py single_frame --epoch 100
84 | ```
85 |
86 | You can then train the DIS-MF network with an arbitrary batch size (e.g. 4) by running
87 |
88 | ```
89 | python train_val.py --train_batch_size 4 --architecture multi_frame
90 | ```
91 | The DIS-MF network can be trained with a batch size of 4 on a device with 24 GB of GPU memory.
92 |
93 | ### Training DIS-FTSF
94 | Before training the DIS-FTSF network, the DIS-MF network must have been trained and its outputs must have been pre-saved.
95 | Make sure that `DATA_DIR` in `config.json` is correct and that the trained weights of the DIS-MF network are available in `OUTPUT_DIR` in `config.json`.
96 | Then, you can pre-save the outputs of a specific epoch (e.g. 50) of the DIS-MF network by running the following command in the `data` directory:
97 |
98 | ```
99 | python presave_disp.py multi_frame --epoch 50
100 | ```
101 |
102 | You can then train the DIS-FTSF network with an arbitrary batch size (e.g. 8) by running
103 |
104 | ```
105 | python train_val.py --train_batch_size 8 --architecture single_frame --use_pseudo_gt True
106 | ```
107 |
108 | ### Evaluating the Networks
109 | To evaluate a specific checkpoint of a specific network, e.g. the 50th epoch of the DIS-MF network, run
110 | ```
111 | python train_val.py --architecture multi_frame --cmd retest --epoch 50
112 | ```
113 | ### Training/Evaluating on the Real Dataset
114 | Our captured real dataset can be downloaded from [here](https://www.idiap.ch/en/dataset/depthinspace/index_html). Make sure to use `--data_type real` when training or evaluating the models on our real dataset so that the same train/test split as in our paper is used.
115 |
116 | ### Pretrained Networks
117 | The pretrained networks for the synthetic datasets with different projection patterns and for the real dataset can be found [here](https://drive.google.com/drive/folders/1uiSElbiQhXxag2VIpXy4lOoueoAEh1ak?usp=sharing). In order to use these networks, make sure `OUTPUT_DIR` in `config.json` points to the proper pretrained directory. Then, use the following for the single-frame model:
118 | ```
119 | python train_val.py --architecture single_frame --cmd retest --epoch 0
120 | ```
121 | and the following for the multi-frame model:
122 | ```
123 | python train_val.py --architecture multi_frame --cmd retest --epoch 0
124 | ```
125 | Make sure to add `--data_type real` when using the models on our real dataset so that the same train/test split as in our paper is used.
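As a recap, all of the path variables referred to in the steps above live in `config.json`. A minimal sketch of that file, assuming only the keys named in this README, could look as follows (the paths are placeholders and the actual file may contain additional keys):
```json
{
    "SHAPENET_DIR": "/path/to/ShapeNetCore.v2",
    "CTD_DIR": "/path/to/connecting_the_dots",
    "LITEFLOWNET_DIR": "/path/to/pytorch-liteflownet",
    "DATA_DIR": "/path/to/depthinspace/data",
    "OUTPUT_DIR": "/path/to/depthinspace/output"
}
```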
126 | 127 | ### Contact 128 | You can contact the author through email: mohammad.johari At idiap.ch. 129 | 130 | ### License 131 | All codes found in this repository are licensed under the [MIT License](LICENSE). 132 | 133 | ## Citing 134 | If you find our work useful, please consider citing: 135 | ```BibTeX 136 | @inproceedings{johari-et-al-2021, 137 | author = {Johari, M. and Carta, C. and Fleuret, F.}, 138 | title = {DepthInSpace: Exploitation and Fusion of Multiple Video Frames for Structured-Light Depth Estimation}, 139 | booktitle = {Proceedings of the IEEE International Conference on Computer Vision (ICCV)}, 140 | year = {2021} 141 | } 142 | ``` 143 | 144 | ### Acknowledgement 145 | This work was supported by ams OSRAM. -------------------------------------------------------------------------------- /co/__init__.py: -------------------------------------------------------------------------------- 1 | # DepthInSpace is a PyTorch-based program which estimates 3D depth maps 2 | # from active structured-light sensor's multiple video frames. 3 | # 4 | # MIT License 5 | # 6 | # Copyright (c) 2019 autonomousvision 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to the following conditions: 14 | # 15 | # The above copyright notice and this permission notice shall be included in all 16 | # copies or substantial portions of the Software. 17 | # 18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | # SOFTWARE. 25 | 26 | 27 | # set matplotlib backend depending on env 28 | import os 29 | import matplotlib 30 | if os.name == 'posix' and "DISPLAY" not in os.environ: 31 | matplotlib.use('Agg') 32 | 33 | from . import geometry 34 | from . import metric 35 | from . import utils 36 | from . import io3d 37 | from . import gtimer 38 | from . import cmap 39 | from . import args 40 | -------------------------------------------------------------------------------- /co/args.py: -------------------------------------------------------------------------------- 1 | # DepthInSpace is a PyTorch-based program which estimates 3D depth maps 2 | # from active structured-light sensor's multiple video frames. 
3 | # 4 | # MIT License 5 | # 6 | # Copyright (c) 2019 autonomousvision 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to the following conditions: 14 | # 15 | # The above copyright notice and this permission notice shall be included in all 16 | # copies or substantial portions of the Software. 17 | # 18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | # SOFTWARE. 25 | 26 | import argparse 27 | from .utils import str2bool 28 | 29 | 30 | def parse_args(): 31 | parser = argparse.ArgumentParser() 32 | # 33 | parser.add_argument('--data_type', 34 | default='synthetic', choices=['synthetic', 'real'], type=str) 35 | # 36 | parser.add_argument('--cmd', 37 | help='Start training or test', 38 | default='resume', choices=['retrain', 'resume', 'retest', 'test_init'], type=str) 39 | parser.add_argument('--epoch', 40 | help='If larger than -1, retest on the specified epoch', 41 | default=-1, type=int) 42 | parser.add_argument('--epochs', 43 | help='Training epochs', 44 | default=100, type=int) 45 | parser.add_argument('--warmup_epochs', 46 | help='Number of epochs where SGM Disparities are used as supervisor when training on the real dataset', 47 | default=150, type=int) 48 | # 49 | parser.add_argument('--lcn_radius', 50 | help='Radius of the window for LCN pre-processing', 51 | default=5, type=int) 52 | parser.add_argument('--max_disp', 53 | help='Maximum disparity', 54 | default=128, type=int) 55 | # 56 | parser.add_argument('--track_length', 57 | help='Track length for geometric loss', 58 | default=4, type=int) 59 | # 60 | parser.add_argument('--train_batch_size', 61 | help='Train Batch Size', 62 | default=8, type=int) 63 | # 64 | parser.add_argument('--architecture', 65 | help='The architecture which will be used', 66 | default='single_frame', choices=['single_frame', 'multi_frame'], type=str) 67 | # 68 | parser.add_argument('--use_pseudo_gt', 69 | help='Only applicable in single-frame model', 70 | default=False, type=str2bool) 71 | 72 | args = parser.parse_args() 73 | 74 | return args -------------------------------------------------------------------------------- /co/cmap.py: -------------------------------------------------------------------------------- 1 | # DepthInSpace is a PyTorch-based program which estimates 3D depth maps 2 | # from active structured-light sensor's multiple video frames. 
3 | # 4 | # MIT License 5 | # 6 | # Copyright (c) 2019 autonomousvision 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to the following conditions: 14 | # 15 | # The above copyright notice and this permission notice shall be included in all 16 | # copies or substantial portions of the Software. 17 | # 18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | # SOFTWARE. 25 | 26 | import numpy as np 27 | 28 | _color_map_errors = np.array([ 29 | [149, 54, 49], # 0: log2(x) = -infinity 30 | [180, 117, 69], # 0.0625: log2(x) = -4 31 | [209, 173, 116], # 0.125: log2(x) = -3 32 | [233, 217, 171], # 0.25: log2(x) = -2 33 | [248, 243, 224], # 0.5: log2(x) = -1 34 | [144, 224, 254], # 1.0: log2(x) = 0 35 | [97, 174, 253], # 2.0: log2(x) = 1 36 | [67, 109, 244], # 4.0: log2(x) = 2 37 | [39, 48, 215], # 8.0: log2(x) = 3 38 | [38, 0, 165], # 16.0: log2(x) = 4 39 | [38, 0, 165] # inf: log2(x) = inf 40 | ]).astype(float) 41 | 42 | 43 | def color_error_image(errors, scale=1.2, log_scale=0.25, mask=None, BGR=True): 44 | """ 45 | Color an input error map. 46 | 47 | Arguments: 48 | errors -- HxW numpy array of errors 49 | [scale=1] -- scaling the error map (color change at unit error) 50 | [mask=None] -- zero-pixels are masked white in the result 51 | [BGR=True] -- toggle between BGR and RGB 52 | 53 | Returns: 54 | colored_errors -- HxWx3 numpy array visualizing the errors 55 | """ 56 | 57 | errors_flat = errors.flatten() 58 | errors_color_indices = np.clip(np.log2(errors_flat / scale + 1e-5) / log_scale + 5, 0, 9) 59 | i0 = np.floor(errors_color_indices).astype(int) 60 | f1 = errors_color_indices - i0.astype(float) 61 | colored_errors_flat = _color_map_errors[i0, :] * (1 - f1).reshape(-1, 1) + _color_map_errors[i0 + 1, 62 | :] * f1.reshape(-1, 1) 63 | 64 | if mask is not None: 65 | colored_errors_flat[mask.flatten() == 0] = 255 66 | 67 | if not BGR: 68 | colored_errors_flat = colored_errors_flat[:, [2, 1, 0]] 69 | 70 | return colored_errors_flat.reshape(errors.shape[0], errors.shape[1], 3).astype(np.int) 71 | 72 | 73 | _color_map_depths = np.array([ 74 | [0, 0, 0], # 0.000 75 | [0, 0, 255], # 0.114 76 | [255, 0, 0], # 0.299 77 | [255, 0, 255], # 0.413 78 | [0, 255, 0], # 0.587 79 | [0, 255, 255], # 0.701 80 | [255, 255, 0], # 0.886 81 | [255, 255, 255], # 1.000 82 | [255, 255, 255], # 1.000 83 | ]).astype(float) 84 | _color_map_bincenters = np.array([ 85 | 0.0, 86 | 0.114, 87 | 0.299, 88 | 0.413, 89 | 0.587, 90 | 0.701, 91 | 0.886, 92 | 1.000, 93 | 2.000, # doesn't make a difference, just strictly higher than 1 94 | ]) 95 | 96 | 97 | def color_depth_map(depths, scale=None): 98 | """ 99 | Color an input depth map. 
100 | 101 | Arguments: 102 | depths -- HxW numpy array of depths 103 | [scale=None] -- scaling the values (defaults to the maximum depth) 104 | 105 | Returns: 106 | colored_depths -- HxWx3 numpy array visualizing the depths 107 | """ 108 | 109 | if scale is None: 110 | scale = depths.max() 111 | 112 | values = np.clip(depths.flatten() / scale, 0, 1) 113 | # for each value, figure out where they fit in in the bincenters: what is the last bincenter smaller than this value? 114 | lower_bin = ((values.reshape(-1, 1) >= _color_map_bincenters.reshape(1, -1)) * np.arange(0, 9)).max(axis=1) 115 | lower_bin_value = _color_map_bincenters[lower_bin] 116 | higher_bin_value = _color_map_bincenters[lower_bin + 1] 117 | alphas = (values - lower_bin_value) / (higher_bin_value - lower_bin_value) 118 | colors = _color_map_depths[lower_bin] * (1 - alphas).reshape(-1, 1) + _color_map_depths[ 119 | lower_bin + 1] * alphas.reshape(-1, 1) 120 | return colors.reshape(depths.shape[0], depths.shape[1], 3).astype(np.uint8) 121 | 122 | # from utils.debug import save_color_numpy 123 | # save_color_numpy(color_depth_map(np.matmul(np.ones((100,1)), np.arange(0,1200).reshape(1,1200)), scale=1000)) 124 | -------------------------------------------------------------------------------- /co/geometry.py: -------------------------------------------------------------------------------- 1 | # DepthInSpace is a PyTorch-based program which estimates 3D depth maps 2 | # from active structured-light sensor's multiple video frames. 3 | # 4 | # MIT License 5 | # 6 | # Copyright (c) 2019 autonomousvision 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to the following conditions: 14 | # 15 | # The above copyright notice and this permission notice shall be included in all 16 | # copies or substantial portions of the Software. 17 | # 18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | # SOFTWARE. 
25 | 26 | import numpy as np 27 | 28 | def nullspace(A, atol=1e-13, rtol=0): 29 | u, s, vh = np.linalg.svd(A) 30 | tol = max(atol, rtol * s[0]) 31 | nnz = (s >= tol).sum() 32 | ns = vh[nnz:].conj().T 33 | return ns 34 | 35 | def nearest_orthogonal_matrix(R): 36 | U,S,Vt = np.linalg.svd(R) 37 | return U @ np.eye(3,dtype=R.dtype) @ Vt 38 | 39 | def power_iters(A, n_iters=10): 40 | b = np.random.uniform(-1,1, size=(A.shape[0], A.shape[1], 1)) 41 | for iter in range(n_iters): 42 | b = A @ b 43 | b = b / np.linalg.norm(b, axis=1, keepdims=True) 44 | return b 45 | 46 | def rayleigh_quotient(A, b): 47 | return (b.transpose(0,2,1) @ A @ b) / (b.transpose(0,2,1) @ b) 48 | 49 | 50 | def cross_prod_mat(x): 51 | x = x.reshape(-1,3) 52 | X = np.empty((x.shape[0],3,3), dtype=x.dtype) 53 | X[:,0,0] = 0 54 | X[:,0,1] = -x[:,2] 55 | X[:,0,2] = x[:,1] 56 | X[:,1,0] = x[:,2] 57 | X[:,1,1] = 0 58 | X[:,1,2] = -x[:,0] 59 | X[:,2,0] = -x[:,1] 60 | X[:,2,1] = x[:,0] 61 | X[:,2,2] = 0 62 | return X.squeeze() 63 | 64 | def hat_operator(x): 65 | return cross_prod_mat(x) 66 | 67 | def vee_operator(X): 68 | X = X.reshape(-1,3,3) 69 | x = np.empty((X.shape[0], 3), dtype=X.dtype) 70 | x[:,0] = X[:,2,1] 71 | x[:,1] = X[:,0,2] 72 | x[:,2] = X[:,1,0] 73 | return x.squeeze() 74 | 75 | 76 | def rot_x(x, dtype=np.float32): 77 | x = np.array(x, copy=False) 78 | x = x.reshape(-1,1) 79 | R = np.zeros((x.shape[0],3,3), dtype=dtype) 80 | R[:,0,0] = 1 81 | R[:,1,1] = np.cos(x).ravel() 82 | R[:,1,2] = -np.sin(x).ravel() 83 | R[:,2,1] = np.sin(x).ravel() 84 | R[:,2,2] = np.cos(x).ravel() 85 | return R.squeeze() 86 | 87 | def rot_y(y, dtype=np.float32): 88 | y = np.array(y, copy=False) 89 | y = y.reshape(-1,1) 90 | R = np.zeros((y.shape[0],3,3), dtype=dtype) 91 | R[:,0,0] = np.cos(y).ravel() 92 | R[:,0,2] = np.sin(y).ravel() 93 | R[:,1,1] = 1 94 | R[:,2,0] = -np.sin(y).ravel() 95 | R[:,2,2] = np.cos(y).ravel() 96 | return R.squeeze() 97 | 98 | def rot_z(z, dtype=np.float32): 99 | z = np.array(z, copy=False) 100 | z = z.reshape(-1,1) 101 | R = np.zeros((z.shape[0],3,3), dtype=dtype) 102 | R[:,0,0] = np.cos(z).ravel() 103 | R[:,0,1] = -np.sin(z).ravel() 104 | R[:,1,0] = np.sin(z).ravel() 105 | R[:,1,1] = np.cos(z).ravel() 106 | R[:,2,2] = 1 107 | return R.squeeze() 108 | 109 | def xyz_from_rotm(R): 110 | R = R.reshape(-1,3,3) 111 | xyz = np.empty((R.shape[0],3), dtype=R.dtype) 112 | for bidx in range(R.shape[0]): 113 | if R[bidx,0,2] < 1: 114 | if R[bidx,0,2] > -1: 115 | xyz[bidx,1] = np.arcsin(R[bidx,0,2]) 116 | xyz[bidx,0] = np.arctan2(-R[bidx,1,2], R[bidx,2,2]) 117 | xyz[bidx,2] = np.arctan2(-R[bidx,0,1], R[bidx,0,0]) 118 | else: 119 | xyz[bidx,1] = -np.pi/2 120 | xyz[bidx,0] = -np.arctan2(R[bidx,1,0],R[bidx,1,1]) 121 | xyz[bidx,2] = 0 122 | else: 123 | xyz[bidx,1] = np.pi/2 124 | xyz[bidx,0] = np.arctan2(R[bidx,1,0], R[bidx,1,1]) 125 | xyz[bidx,2] = 0 126 | return xyz.squeeze() 127 | 128 | def zyx_from_rotm(R): 129 | R = R.reshape(-1,3,3) 130 | zyx = np.empty((R.shape[0],3), dtype=R.dtype) 131 | for bidx in range(R.shape[0]): 132 | if R[bidx,2,0] < 1: 133 | if R[bidx,2,0] > -1: 134 | zyx[bidx,1] = np.arcsin(-R[bidx,2,0]) 135 | zyx[bidx,0] = np.arctan2(R[bidx,1,0], R[bidx,0,0]) 136 | zyx[bidx,2] = np.arctan2(R[bidx,2,1], R[bidx,2,2]) 137 | else: 138 | zyx[bidx,1] = np.pi / 2 139 | zyx[bidx,0] = -np.arctan2(-R[bidx,1,2], R[bidx,1,1]) 140 | zyx[bidx,2] = 0 141 | else: 142 | zyx[bidx,1] = -np.pi / 2 143 | zyx[bidx,0] = np.arctan2(-R[bidx,1,2], R[bidx,1,1]) 144 | zyx[bidx,2] = 0 145 | return zyx.squeeze() 146 | 147 | def 
rotm_from_xyz(xyz): 148 | xyz = np.array(xyz, copy=False).reshape(-1,3) 149 | return (rot_x(xyz[:,0]) @ rot_y(xyz[:,1]) @ rot_z(xyz[:,2])).squeeze() 150 | 151 | def rotm_from_zyx(zyx): 152 | zyx = np.array(zyx, copy=False).reshape(-1,3) 153 | return (rot_z(zyx[:,0]) @ rot_y(zyx[:,1]) @ rot_x(zyx[:,2])).squeeze() 154 | 155 | def rotm_from_quat(q): 156 | q = q.reshape(-1,4) 157 | w, x, y, z = q[:,0], q[:,1], q[:,2], q[:,3] 158 | R = np.array([ 159 | [1 - 2*y*y - 2*z*z, 2*x*y - 2*z*w, 2*x*z + 2*y*w], 160 | [2*x*y + 2*z*w, 1 - 2*x*x - 2*z*z, 2*y*z - 2*x*w], 161 | [2*x*z - 2*y*w, 2*y*z + 2*x*w, 1 - 2*x*x - 2*y*y] 162 | ], dtype=q.dtype) 163 | R = R.transpose((2,0,1)) 164 | return R.squeeze() 165 | 166 | def rotm_from_axisangle(a): 167 | # exponential 168 | a = a.reshape(-1,3) 169 | phi = np.linalg.norm(a, axis=1).reshape(-1,1,1) 170 | iphi = np.zeros_like(phi) 171 | np.divide(1, phi, out=iphi, where=phi != 0) 172 | A = cross_prod_mat(a) * iphi 173 | R = np.eye(3, dtype=a.dtype) + np.sin(phi) * A + (1 - np.cos(phi)) * A @ A 174 | return R.squeeze() 175 | 176 | def rotm_from_lookat(dir, up=None): 177 | dir = dir.reshape(-1,3) 178 | if up is None: 179 | up = np.zeros_like(dir) 180 | up[:,1] = 1 181 | dir /= np.linalg.norm(dir, axis=1, keepdims=True) 182 | up /= np.linalg.norm(up, axis=1, keepdims=True) 183 | x = dir[:,None,:] @ cross_prod_mat(up).transpose(0,2,1) 184 | y = x @ cross_prod_mat(dir).transpose(0,2,1) 185 | x = x.squeeze() 186 | y = y.squeeze() 187 | x /= np.linalg.norm(x, axis=1, keepdims=True) 188 | y /= np.linalg.norm(y, axis=1, keepdims=True) 189 | R = np.empty((dir.shape[0],3,3), dtype=dir.dtype) 190 | R[:,0,0] = x[:,0] 191 | R[:,0,1] = y[:,0] 192 | R[:,0,2] = dir[:,0] 193 | R[:,1,0] = x[:,1] 194 | R[:,1,1] = y[:,1] 195 | R[:,1,2] = dir[:,1] 196 | R[:,2,0] = x[:,2] 197 | R[:,2,1] = y[:,2] 198 | R[:,2,2] = dir[:,2] 199 | return R.transpose(0,2,1).squeeze() 200 | 201 | def rotm_distance_identity(R0, R1): 202 | # https://link.springer.com/article/10.1007%2Fs10851-009-0161-2 203 | # in [0, 2*sqrt(2)] 204 | R0 = R0.reshape(-1,3,3) 205 | R1 = R1.reshape(-1,3,3) 206 | dists = np.linalg.norm(np.eye(3,dtype=R0.dtype) - R0 @ R1.transpose(0,2,1), axis=(1,2)) 207 | return dists.squeeze() 208 | 209 | def rotm_distance_geodesic(R0, R1): 210 | # https://link.springer.com/article/10.1007%2Fs10851-009-0161-2 211 | # in [0, pi) 212 | R0 = R0.reshape(-1,3,3) 213 | R1 = R1.reshape(-1,3,3) 214 | RtR = R0 @ R1.transpose(0,2,1) 215 | aa = axisangle_from_rotm(RtR) 216 | S = cross_prod_mat(aa).reshape(-1,3,3) 217 | dists = np.linalg.norm(S, axis=(1,2)) 218 | return dists.squeeze() 219 | 220 | 221 | 222 | def axisangle_from_rotm(R): 223 | # logarithm of rotation matrix 224 | # R = R.reshape(-1,3,3) 225 | # tr = np.trace(R, axis1=1, axis2=2) 226 | # phi = np.arccos(np.clip((tr - 1) / 2, -1, 1)) 227 | # scale = np.zeros_like(phi) 228 | # div = 2 * np.sin(phi) 229 | # np.divide(phi, div, out=scale, where=np.abs(div) > 1e-6) 230 | # A = (R - R.transpose(0,2,1)) * scale.reshape(-1,1,1) 231 | # aa = np.stack((A[:,2,1], A[:,0,2], A[:,1,0]), axis=1) 232 | # return aa.squeeze() 233 | R = R.reshape(-1,3,3) 234 | omega = np.empty((R.shape[0], 3), dtype=R.dtype) 235 | omega[:,0] = R[:,2,1] - R[:,1,2] 236 | omega[:,1] = R[:,0,2] - R[:,2,0] 237 | omega[:,2] = R[:,1,0] - R[:,0,1] 238 | r = np.linalg.norm(omega, axis=1).reshape(-1,1) 239 | t = np.trace(R, axis1=1, axis2=2).reshape(-1,1) 240 | omega = np.arctan2(r, t-1) * omega 241 | aa = np.zeros_like(omega) 242 | np.divide(omega, r, out=aa, where=r != 0) 243 | return 
aa.squeeze() 244 | 245 | def axisangle_from_quat(q): 246 | q = q.reshape(-1,4) 247 | phi = 2 * np.arccos(q[:,0]) 248 | denom = np.zeros_like(q[:,0]) 249 | np.divide(1, np.sqrt(1 - q[:,0]**2), out=denom, where=q[:,0] != 1) 250 | axis = q[:,1:] * denom.reshape(-1,1) 251 | denom = np.linalg.norm(axis, axis=1).reshape(-1,1) 252 | a = np.zeros_like(axis) 253 | np.divide(phi.reshape(-1,1) * axis, denom, out=a, where=denom != 0) 254 | aa = a.astype(q.dtype) 255 | return aa.squeeze() 256 | 257 | def axisangle_apply(aa, x): 258 | # working only with single aa and single x at the moment 259 | xshape = x.shape 260 | aa = aa.reshape(3,) 261 | x = x.reshape(3,) 262 | phi = np.linalg.norm(aa) 263 | e = np.zeros_like(aa) 264 | np.divide(aa, phi, out=e, where=phi != 0) 265 | xr = np.cos(phi) * x + np.sin(phi) * np.cross(e, x) + (1 - np.cos(phi)) * (e.T @ x) * e 266 | return xr.reshape(xshape) 267 | 268 | 269 | def exp_so3(R): 270 | w = axisangle_from_rotm(R) 271 | return w 272 | 273 | def log_so3(w): 274 | R = rotm_from_axisangle(w) 275 | return R 276 | 277 | def exp_se3(R, t): 278 | R = R.reshape(-1,3,3) 279 | t = t.reshape(-1,3) 280 | 281 | w = exp_so3(R).reshape(-1,3) 282 | 283 | phi = np.linalg.norm(w, axis=1).reshape(-1,1,1) 284 | A = cross_prod_mat(w) 285 | Vi = np.eye(3, dtype=R.dtype) - A/2 + (1 - (phi * np.sin(phi) / (2 * (1 - np.cos(phi))))) / phi**2 * A @ A 286 | u = t.reshape(-1,1,3) @ Vi.transpose(0,2,1) 287 | 288 | # v = (u, w) 289 | v = np.empty((R.shape[0],6), dtype=R.dtype) 290 | v[:,:3] = u.squeeze() 291 | v[:,3:] = w 292 | 293 | return v.squeeze() 294 | 295 | def log_se3(v): 296 | # v = (u, w) 297 | v = v.reshape(-1,6) 298 | u = v[:,:3] 299 | w = v[:,3:] 300 | 301 | R = log_so3(w) 302 | 303 | phi = np.linalg.norm(w, axis=1).reshape(-1,1,1) 304 | A = cross_prod_mat(w) 305 | V = np.eye(3, dtype=v.dtype) + (1 - np.cos(phi)) / phi**2 * A + (phi - np.sin(phi)) / phi**3 * A @ A 306 | t = u.reshape(-1,1,3) @ V.transpose(0,2,1) 307 | 308 | return R.squeeze(), t.squeeze() 309 | 310 | 311 | def quat_from_rotm(R): 312 | R = R.reshape(-1,3,3) 313 | q = np.empty((R.shape[0], 4,), dtype=R.dtype) 314 | q[:,0] = np.sqrt( np.maximum(0, 1 + R[:,0,0] + R[:,1,1] + R[:,2,2]) ) 315 | q[:,1] = np.sqrt( np.maximum(0, 1 + R[:,0,0] - R[:,1,1] - R[:,2,2]) ) 316 | q[:,2] = np.sqrt( np.maximum(0, 1 - R[:,0,0] + R[:,1,1] - R[:,2,2]) ) 317 | q[:,3] = np.sqrt( np.maximum(0, 1 - R[:,0,0] - R[:,1,1] + R[:,2,2]) ) 318 | q[:,1] *= np.sign(q[:,1] * (R[:,2,1] - R[:,1,2])) 319 | q[:,2] *= np.sign(q[:,2] * (R[:,0,2] - R[:,2,0])) 320 | q[:,3] *= np.sign(q[:,3] * (R[:,1,0] - R[:,0,1])) 321 | q /= np.linalg.norm(q,axis=1,keepdims=True) 322 | return q.squeeze() 323 | 324 | def quat_from_axisangle(a): 325 | a = a.reshape(-1, 3) 326 | phi = np.linalg.norm(a, axis=1) 327 | iphi = np.zeros_like(phi) 328 | np.divide(1, phi, out=iphi, where=phi != 0) 329 | a = a * iphi.reshape(-1,1) 330 | theta = phi / 2.0 331 | r = np.cos(theta) 332 | stheta = np.sin(theta) 333 | q = np.stack((r, stheta*a[:,0], stheta*a[:,1], stheta*a[:,2]), axis=1) 334 | q /= np.linalg.norm(q, axis=1).reshape(-1,1) 335 | return q.squeeze() 336 | 337 | def quat_identity(n=1, dtype=np.float32): 338 | q = np.zeros((n,4), dtype=dtype) 339 | q[:,0] = 1 340 | return q.squeeze() 341 | 342 | def quat_conjugate(q): 343 | shape = q.shape 344 | q = q.reshape(-1,4).copy() 345 | q[:,1:] *= -1 346 | return q.reshape(shape) 347 | 348 | def quat_product(q1, q2): 349 | # q1 . 
q2 is equivalent to R(q1) @ R(q2) 350 | shape = q1.shape 351 | q1, q2 = q1.reshape(-1,4), q2.reshape(-1, 4) 352 | q = np.empty((max(q1.shape[0], q2.shape[0]), 4), dtype=q1.dtype) 353 | a1,b1,c1,d1 = q1[:,0], q1[:,1], q1[:,2], q1[:,3] 354 | a2,b2,c2,d2 = q2[:,0], q2[:,1], q2[:,2], q2[:,3] 355 | q[:,0] = a1 * a2 - b1 * b2 - c1 * c2 - d1 * d2 356 | q[:,1] = a1 * b2 + b1 * a2 + c1 * d2 - d1 * c2 357 | q[:,2] = a1 * c2 - b1 * d2 + c1 * a2 + d1 * b2 358 | q[:,3] = a1 * d2 + b1 * c2 - c1 * b2 + d1 * a2 359 | return q.squeeze() 360 | 361 | def quat_apply(q, x): 362 | xshape = x.shape 363 | x = x.reshape(-1, 3) 364 | qshape = q.shape 365 | q = q.reshape(-1, 4) 366 | 367 | p = np.empty((x.shape[0], 4), dtype=x.dtype) 368 | p[:,0] = 0 369 | p[:,1:] = x 370 | 371 | r = quat_product(quat_product(q, p), quat_conjugate(q)) 372 | if r.ndim == 1: 373 | return r[1:].reshape(xshape) 374 | else: 375 | return r[:,1:].reshape(xshape) 376 | 377 | 378 | def quat_random(rng=None, n=1): 379 | # http://planning.cs.uiuc.edu/node198.html 380 | if rng is not None: 381 | u = rng.uniform(0, 1, size=(3,n)) 382 | else: 383 | u = np.random.uniform(0, 1, size=(3,n)) 384 | q = np.array(( 385 | np.sqrt(1 - u[0]) * np.sin(2 * np.pi * u[1]), 386 | np.sqrt(1 - u[0]) * np.cos(2 * np.pi * u[1]), 387 | np.sqrt(u[0]) * np.sin(2 * np.pi * u[2]), 388 | np.sqrt(u[0]) * np.cos(2 * np.pi * u[2]) 389 | )).T 390 | q /= np.linalg.norm(q,axis=1,keepdims=True) 391 | return q.squeeze() 392 | 393 | def quat_distance_angle(q0, q1): 394 | # https://math.stackexchange.com/questions/90081/quaternion-distance 395 | # https://link.springer.com/article/10.1007%2Fs10851-009-0161-2 396 | q0 = q0.reshape(-1,4) 397 | q1 = q1.reshape(-1,4) 398 | dists = np.arccos(np.clip(2 * np.sum(q0 * q1, axis=1)**2 - 1, -1, 1)) 399 | return dists 400 | 401 | def quat_distance_normdiff(q0, q1): 402 | # https://link.springer.com/article/10.1007%2Fs10851-009-0161-2 403 | # \phi_4 404 | # [0, 1] 405 | q0 = q0.reshape(-1,4) 406 | q1 = q1.reshape(-1,4) 407 | return 1 - np.sum(q0 * q1, axis=1)**2 408 | 409 | def quat_distance_mineucl(q0, q1): 410 | # https://link.springer.com/article/10.1007%2Fs10851-009-0161-2 411 | # http://users.cecs.anu.edu.au/~trumpf/pubs/Hartley_Trumpf_Dai_Li.pdf 412 | q0 = q0.reshape(-1,4) 413 | q1 = q1.reshape(-1,4) 414 | diff0 = ((q0 - q1)**2).sum(axis=1) 415 | diff1 = ((q0 + q1)**2).sum(axis=1) 416 | return np.minimum(diff0, diff1) 417 | 418 | def quat_slerp_space(q0, q1, num=100, endpoint=True): 419 | q0 = q0.ravel() 420 | q1 = q1.ravel() 421 | dot = q0.dot(q1) 422 | if dot < 0: 423 | q1 *= -1 424 | dot *= -1 425 | t = np.linspace(0, 1, num=num, endpoint=endpoint, dtype=q0.dtype) 426 | t = t.reshape((-1,1)) 427 | if dot > 0.9995: 428 | ret = q0 + t * (q1 - q0) 429 | return ret 430 | dot = np.clip(dot, -1, 1) 431 | theta0 = np.arccos(dot) 432 | theta = theta0 * t 433 | s0 = np.cos(theta) - dot * np.sin(theta) / np.sin(theta0) 434 | s1 = np.sin(theta) / np.sin(theta0) 435 | return (s0 * q0) + (s1 * q1) 436 | 437 | def cart_to_spherical(x): 438 | shape = x.shape 439 | x = x.reshape(-1,3) 440 | y = np.empty_like(x) 441 | y[:,0] = np.linalg.norm(x, axis=1) # r 442 | y[:,1] = np.arccos(x[:,2] / y[:,0]) # theta 443 | y[:,2] = np.arctan2(x[:,1], x[:,0]) # phi 444 | return y.reshape(shape) 445 | 446 | def spherical_to_cart(x): 447 | shape = x.shape 448 | x = x.reshape(-1,3) 449 | y = np.empty_like(x) 450 | y[:,0] = x[:,0] * np.sin(x[:,1]) * np.cos(x[:,2]) 451 | y[:,1] = x[:,0] * np.sin(x[:,1]) * np.sin(x[:,2]) 452 | y[:,2] = x[:,0] * np.cos(x[:,1]) 453 | 
return y.reshape(shape) 454 | 455 | def spherical_random(r=1, n=1): 456 | # http://mathworld.wolfram.com/SpherePointPicking.html 457 | # https://math.stackexchange.com/questions/1585975/how-to-generate-random-points-on-a-sphere 458 | x = np.empty((n,3)) 459 | x[:,0] = r 460 | x[:,1] = 2 * np.pi * np.random.uniform(0,1, size=(n,)) 461 | x[:,2] = np.arccos(2 * np.random.uniform(0,1, size=(n,)) - 1) 462 | return x.squeeze() 463 | 464 | def color_pcl(pcl, K, im, color_axis=0, as_int=True, invalid_color=[0,0,0]): 465 | uvd = K @ pcl.T 466 | uvd /= uvd[2] 467 | uvd = np.round(uvd).astype(np.int32) 468 | mask = np.logical_and(uvd[0] >= 0, uvd[1] >= 0) 469 | color = np.empty((pcl.shape[0], 3), dtype=im.dtype) 470 | if color_axis == 0: 471 | mask = np.logical_and(mask, uvd[0] < im.shape[2]) 472 | mask = np.logical_and(mask, uvd[1] < im.shape[1]) 473 | uvd = uvd[:,mask] 474 | color[mask,:] = im[:,uvd[1],uvd[0]].T 475 | elif color_axis == 2: 476 | mask = np.logical_and(mask, uvd[0] < im.shape[1]) 477 | mask = np.logical_and(mask, uvd[1] < im.shape[0]) 478 | uvd = uvd[:,mask] 479 | color[mask,:] = im[uvd[1],uvd[0], :] 480 | else: 481 | raise Exception('invalid color_axis') 482 | color[np.logical_not(mask),:3] = invalid_color 483 | if as_int: 484 | color = (255.0 * color).astype(np.int32) 485 | return color 486 | 487 | def center_pcl(pcl, robust=False, copy=False, axis=1): 488 | if copy: 489 | pcl = pcl.copy() 490 | if robust: 491 | mu = np.median(pcl, axis=axis, keepdims=True) 492 | else: 493 | mu = np.mean(pcl, axis=axis, keepdims=True) 494 | return pcl - mu 495 | 496 | def to_homogeneous(x): 497 | # return np.hstack((x, np.ones((x.shape[0],1),dtype=x.dtype))) 498 | return np.concatenate((x, np.ones((*x.shape[:-1],1),dtype=x.dtype)), axis=-1) 499 | 500 | def from_homogeneous(x): 501 | return x[:,:-1] / x[:,-1] 502 | 503 | def project_uvn(uv, Ki=None): 504 | if uv.shape[1] == 2: 505 | uvn = to_homogeneous(uv) 506 | else: 507 | uvn = uv 508 | if uvn.shape[1] != 3: 509 | raise Exception('uv should have shape Nx2 or Nx3') 510 | if Ki is None: 511 | return uvn 512 | else: 513 | return uvn @ Ki.T 514 | 515 | def project_uvd(uv, depth, K=np.eye(3), R=np.eye(3), t=np.zeros((3,1)), ignore_negative_depth=True, return_uvn=False): 516 | Ki = np.linalg.inv(K) 517 | 518 | if ignore_negative_depth: 519 | mask = depth >= 0 520 | uv = uv[mask,:] 521 | d = depth[mask] 522 | else: 523 | d = depth.ravel() 524 | 525 | uv1 = to_homogeneous(uv) 526 | 527 | uvn1 = uv1 @ Ki.T 528 | xyz = d.reshape(-1,1) * uvn1 529 | xyz = (xyz - t.reshape((1,3))) @ R 530 | 531 | if return_uvn: 532 | return xyz, uvn1 533 | else: 534 | return xyz 535 | 536 | def project_depth(depth, K, R=np.eye(3,3), t=np.zeros((3,1)), ignore_negative_depth=True, return_uvn=False): 537 | u, v = np.meshgrid(range(depth.shape[1]), range(depth.shape[0])) 538 | uv = np.hstack((u.reshape(-1,1), v.reshape(-1,1))) 539 | return project_uvd(uv, depth.ravel(), K, R, t, ignore_negative_depth, return_uvn) 540 | 541 | 542 | def project_xyz(xyz, K=np.eye(3), R=np.eye(3,3), t=np.zeros((3,1))): 543 | uvd = K @ (R @ xyz.T + t.reshape((3,1))) 544 | uvd[:2] /= uvd[2] 545 | return uvd[:2].T, uvd[2] 546 | 547 | 548 | def relative_motion(R0, t0, R1, t1, Rt_from_global=True): 549 | t0 = t0.reshape((3,1)) 550 | t1 = t1.reshape((3,1)) 551 | if Rt_from_global: 552 | Rr = R1 @ R0.T 553 | tr = t1 - Rr @ t0 554 | else: 555 | Rr = R1.T @ R0 556 | tr = R1.T @ (t0 - t1) 557 | return Rr, tr.ravel() 558 | 559 | 560 | def translation_to_cameracenter(R, t): 561 | t = t.reshape(-1,3,1) 562 | R = 
R.reshape(-1,3,3) 563 | C = -R.transpose(0,2,1) @ t 564 | return C.squeeze() 565 | 566 | def cameracenter_to_translation(R, C): 567 | C = C.reshape(-1,3,1) 568 | R = R.reshape(-1,3,3) 569 | t = -R @ C 570 | return t.squeeze() 571 | 572 | def decompose_projection_matrix(P, return_t=True): 573 | if P.shape[0] != 3 or P.shape[1] != 4: 574 | raise Exception('P has to be 3x4') 575 | M = P[:, :3] 576 | C = -np.linalg.inv(M) @ P[:, 3:] 577 | 578 | R,K = np.linalg.qr(np.flipud(M).T) 579 | K = np.flipud(K.T) 580 | K = np.fliplr(K) 581 | R = np.flipud(R.T) 582 | 583 | T = np.diag(np.sign(np.diag(K))) 584 | K = K @ T 585 | R = T @ R 586 | 587 | if np.linalg.det(R) < 0: 588 | R *= -1 589 | 590 | K /= K[2,2] 591 | if return_t: 592 | return K, R, cameracenter_to_translation(R, C) 593 | else: 594 | return K, R, C 595 | 596 | 597 | def compose_projection_matrix(K=np.eye(3), R=np.eye(3,3), t=np.zeros((3,1))): 598 | return K @ np.hstack((R, t.reshape((3,1)))) 599 | 600 | 601 | 602 | def point_plane_distance(pts, plane): 603 | pts = pts.reshape(-1,3) 604 | return np.abs(np.sum(plane[:3] * pts, axis=1) + plane[3]) / np.linalg.norm(plane[:3]) 605 | 606 | def fit_plane(pts): 607 | pts = pts.reshape(-1,3) 608 | center = np.mean(pts, axis=0) 609 | A = pts - center 610 | u, s, vh = np.linalg.svd(A, full_matrices=False) 611 | plane = np.array([*vh[2], -vh[2].dot(center)]) 612 | return plane 613 | 614 | def tetrahedron(dtype=np.float32): 615 | verts = np.array([ 616 | (np.sqrt(8/9), 0, -1/3), (-np.sqrt(2/9), np.sqrt(2/3), -1/3), 617 | (-np.sqrt(2/9), -np.sqrt(2/3), -1/3), (0, 0, 1)], dtype=dtype) 618 | faces = np.array([(0,1,2), (0,2,3), (0,1,3), (1,2,3)], dtype=np.int32) 619 | normals = -np.mean(verts, axis=0) + verts 620 | normals /= np.linalg.norm(normals, axis=1).reshape(-1,1) 621 | return verts, faces, normals 622 | 623 | def cube(dtype=np.float32): 624 | verts = np.array([ 625 | [-0.5,-0.5,-0.5], [-0.5,0.5,-0.5], [0.5,0.5,-0.5], [0.5,-0.5,-0.5], 626 | [-0.5,-0.5,0.5], [-0.5,0.5,0.5], [0.5,0.5,0.5], [0.5,-0.5,0.5]], dtype=dtype) 627 | faces = np.array([ 628 | (0,1,2), (0,2,3), (4,5,6), (4,6,7), 629 | (0,4,7), (0,7,3), (1,5,6), (1,6,2), 630 | (3,2,6), (3,6,7), (0,1,5), (0,5,4)], dtype=np.int32) 631 | normals = -np.mean(verts, axis=0) + verts 632 | normals /= np.linalg.norm(normals, axis=1).reshape(-1,1) 633 | return verts, faces, normals 634 | 635 | def octahedron(dtype=np.float32): 636 | verts = np.array([ 637 | (+1,0,0), (0,+1,0), (0,0,+1), 638 | (-1,0,0), (0,-1,0), (0,0,-1)], dtype=dtype) 639 | faces = np.array([ 640 | (0,1,2), (1,2,3), (3,2,4), (4,2,0), 641 | (0,1,5), (1,5,3), (3,5,4), (4,5,0)], dtype=np.int32) 642 | normals = -np.mean(verts, axis=0) + verts 643 | normals /= np.linalg.norm(normals, axis=1).reshape(-1,1) 644 | return verts, faces, normals 645 | 646 | def icosahedron(dtype=np.float32): 647 | p = (1 + np.sqrt(5)) / 2 648 | verts = np.array([ 649 | (-1,0,p), (1,0,p), (1,0,-p), (-1,0,-p), 650 | (0,-p,1), (0,p,1), (0,p,-1), (0,-p,-1), 651 | (-p,-1,0), (p,-1,0), (p,1,0), (-p,1,0) 652 | ], dtype=dtype) 653 | faces = np.array([ 654 | (0,1,4), (0,1,5), (1,4,9), (1,9,10), (1,10,5), (0,4,8), (0,8,11), (0,11,5), 655 | (5,6,11), (5,6,10), (4,7,8), (4,7,9), 656 | (3,2,6), (3,2,7), (2,6,10), (2,10,9), (2,9,7), (3,6,11), (3,11,8), (3,8,7), 657 | ], dtype=np.int32) 658 | normals = -np.mean(verts, axis=0) + verts 659 | normals /= np.linalg.norm(normals, axis=1).reshape(-1,1) 660 | return verts, faces, normals 661 | 662 | def xyplane(dtype=np.float32, z=0, interleaved=False): 663 | if interleaved: 664 | eps = 
1e-6 665 | verts = np.array([ 666 | (-1,-1,z), (-1,1,z), (1,1,z), 667 | (1-eps,1,z), (1-eps,-1,z), (-1-eps,-1,z)], dtype=dtype) 668 | faces = np.array([(0,1,2), (3,4,5)], dtype=np.int32) 669 | else: 670 | verts = np.array([(-1,-1,z), (-1,1,z), (1,1,z), (1,-1,z)], dtype=dtype) 671 | faces = np.array([(0,1,2), (0,2,3)], dtype=np.int32) 672 | normals = np.zeros_like(verts) 673 | normals[:,2] = -1 674 | return verts, faces, normals 675 | 676 | def mesh_independent_verts(verts, faces, normals=None): 677 | new_verts = [] 678 | new_normals = [] 679 | for f in faces: 680 | new_verts.append(verts[f[0]]) 681 | new_verts.append(verts[f[1]]) 682 | new_verts.append(verts[f[2]]) 683 | if normals is not None: 684 | new_normals.append(normals[f[0]]) 685 | new_normals.append(normals[f[1]]) 686 | new_normals.append(normals[f[2]]) 687 | new_verts = np.array(new_verts) 688 | new_faces = np.arange(0, faces.size, dtype=faces.dtype).reshape(-1,3) 689 | if normals is None: 690 | return new_verts, new_faces 691 | else: 692 | new_normals = np.array(new_normals) 693 | return new_verts, new_faces, new_normals 694 | 695 | 696 | def stack_mesh(verts, faces): 697 | n_verts = 0 698 | mfaces = [] 699 | for idx, f in enumerate(faces): 700 | mfaces.append(f + n_verts) 701 | n_verts += verts[idx].shape[0] 702 | verts = np.vstack(verts) 703 | faces = np.vstack(mfaces) 704 | return verts, faces 705 | 706 | def normalize_mesh(verts): 707 | # all the verts have unit distance to the center (0,0,0) 708 | return verts / np.linalg.norm(verts, axis=1, keepdims=True) 709 | 710 | 711 | def mesh_triangle_areas(verts, faces): 712 | a = verts[faces[:,0]] 713 | b = verts[faces[:,1]] 714 | c = verts[faces[:,2]] 715 | x = np.empty_like(a) 716 | x = a - b 717 | y = a - c 718 | t = np.empty_like(a) 719 | t[:,0] = (x[:,1] * y[:,2] - x[:,2] * y[:,1]); 720 | t[:,1] = (x[:,2] * y[:,0] - x[:,0] * y[:,2]); 721 | t[:,2] = (x[:,0] * y[:,1] - x[:,1] * y[:,0]); 722 | return np.linalg.norm(t, axis=1) / 2 723 | 724 | def subdivde_mesh(verts_in, faces_in, n=1): 725 | for iter in range(n): 726 | verts = [] 727 | for v in verts_in: 728 | verts.append(v) 729 | faces = [] 730 | verts_dict = {} 731 | for f in faces_in: 732 | f = np.sort(f) 733 | i0,i1,i2 = f 734 | v0,v1,v2 = verts_in[f] 735 | 736 | k = i0*len(verts_in)+i1 737 | if k in verts_dict: 738 | i01 = verts_dict[k] 739 | else: 740 | i01 = len(verts) 741 | verts_dict[k] = i01 742 | v01 = (v0 + v1) / 2 743 | verts.append(v01) 744 | 745 | k = i0*len(verts_in)+i2 746 | if k in verts_dict: 747 | i02 = verts_dict[k] 748 | else: 749 | i02 = len(verts) 750 | verts_dict[k] = i02 751 | v02 = (v0 + v2) / 2 752 | verts.append(v02) 753 | 754 | k = i1*len(verts_in)+i2 755 | if k in verts_dict: 756 | i12 = verts_dict[k] 757 | else: 758 | i12 = len(verts) 759 | verts_dict[k] = i12 760 | v12 = (v1 + v2) / 2 761 | verts.append(v12) 762 | 763 | faces.append((i0,i01,i02)) 764 | faces.append((i01,i1,i12)) 765 | faces.append((i12,i2,i02)) 766 | faces.append((i01,i12,i02)) 767 | 768 | verts_in = np.array(verts, dtype=verts_in.dtype) 769 | faces_in = np.array(faces, dtype=np.int32) 770 | return verts_in, faces_in 771 | 772 | 773 | def mesh_adjust_winding_order(verts, faces, normals): 774 | n0 = normals[faces[:,0]] 775 | n1 = normals[faces[:,1]] 776 | n2 = normals[faces[:,2]] 777 | fnormals = (n0 + n1 + n2) / 3 778 | 779 | v0 = verts[faces[:,0]] 780 | v1 = verts[faces[:,1]] 781 | v2 = verts[faces[:,2]] 782 | 783 | e0 = v1 - v0 784 | e1 = v2 - v0 785 | fn = np.cross(e0, e1) 786 | 787 | dot = np.sum(fnormals * fn, axis=1) 788 | 
ma = dot < 0 789 | 790 | nfaces = faces.copy() 791 | nfaces[ma,1], nfaces[ma,2] = nfaces[ma,2], nfaces[ma,1] 792 | 793 | return nfaces 794 | 795 | 796 | def pcl_to_shapecl(verts, colors=None, shape='cube', width=1.0): 797 | if shape == 'tetrahedron': 798 | cverts, cfaces, _ = tetrahedron() 799 | elif shape == 'cube': 800 | cverts, cfaces, _ = cube() 801 | elif shape == 'octahedron': 802 | cverts, cfaces, _ = octahedron() 803 | elif shape == 'icosahedron': 804 | cverts, cfaces, _ = icosahedron() 805 | else: 806 | raise Exception('invalid shape') 807 | 808 | sverts = np.tile(cverts, (verts.shape[0], 1)) 809 | sverts *= width 810 | sverts += np.repeat(verts, cverts.shape[0], axis=0) 811 | 812 | sfaces = np.tile(cfaces, (verts.shape[0], 1)) 813 | sfoffset = cverts.shape[0] * np.arange(0, verts.shape[0]) 814 | sfaces += np.repeat(sfoffset, cfaces.shape[0]).reshape(-1,1) 815 | 816 | if colors is not None: 817 | scolors = np.repeat(colors, cverts.shape[0], axis=0) 818 | else: 819 | scolors = None 820 | 821 | return sverts, sfaces, scolors 822 | -------------------------------------------------------------------------------- /co/gtimer.py: -------------------------------------------------------------------------------- 1 | # DepthInSpace is a PyTorch-based program which estimates 3D depth maps 2 | # from active structured-light sensor's multiple video frames. 3 | # 4 | # MIT License 5 | # 6 | # Copyright (c) 2019 autonomousvision 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to the following conditions: 14 | # 15 | # The above copyright notice and this permission notice shall be included in all 16 | # copies or substantial portions of the Software. 17 | # 18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | # SOFTWARE. 25 | 26 | import numpy as np 27 | 28 | from . 
import utils 29 | 30 | class StopWatch(utils.StopWatch): 31 | def __del__(self): 32 | print('='*80) 33 | print('gtimer:') 34 | total = ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get(reduce=np.sum).items()]) 35 | print(f' [total] {total}') 36 | mean = ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get(reduce=np.mean).items()]) 37 | print(f' [mean] {mean}') 38 | median = ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get(reduce=np.median).items()]) 39 | print(f' [median] {median}') 40 | print('='*80) 41 | 42 | GTIMER = StopWatch() 43 | 44 | def start(name): 45 | GTIMER.start(name) 46 | def stop(name): 47 | GTIMER.stop(name) 48 | 49 | class Ctx(object): 50 | def __init__(self, name): 51 | self.name = name 52 | 53 | def __enter__(self): 54 | start(self.name) 55 | 56 | def __exit__(self, *args): 57 | stop(self.name) 58 | -------------------------------------------------------------------------------- /co/io3d.py: -------------------------------------------------------------------------------- 1 | # DepthInSpace is a PyTorch-based program which estimates 3D depth maps 2 | # from active structured-light sensor's multiple video frames. 3 | # 4 | # MIT License 5 | # 6 | # Copyright (c) 2019 autonomousvision 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to the following conditions: 14 | # 15 | # The above copyright notice and this permission notice shall be included in all 16 | # copies or substantial portions of the Software. 17 | # 18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | # SOFTWARE. 
25 | 26 | import struct 27 | import numpy as np 28 | import collections 29 | 30 | def _write_ply_point(fp, x,y,z, color=None, normal=None, binary=False): 31 | args = [x,y,z] 32 | if color is not None: 33 | args += [int(color[0]), int(color[1]), int(color[2])] 34 | if normal is not None: 35 | args += [normal[0],normal[1],normal[2]] 36 | if binary: 37 | fmt = ' 1: 101 | c = color[vidx] 102 | else: 103 | c = color[0] 104 | else: 105 | c = None 106 | if normals is None: 107 | n = None 108 | else: 109 | n = normals[vidx] 110 | _write_ply_point(fp, v[0],v[1],v[2], c, n, binary) 111 | 112 | if trias is not None: 113 | for t in trias: 114 | _write_ply_triangle(fp, t[0],t[1],t[2], binary) 115 | 116 | def faces_to_triangles(faces): 117 | new_faces = [] 118 | for f in faces: 119 | if f[0] == 3: 120 | new_faces.append([f[1], f[2], f[3]]) 121 | elif f[0] == 4: 122 | new_faces.append([f[1], f[2], f[3]]) 123 | new_faces.append([f[3], f[4], f[1]]) 124 | else: 125 | raise Exception('unknown face count %d', f[0]) 126 | return new_faces 127 | 128 | def read_ply(path): 129 | with open(path, 'rb') as f: 130 | # parse header 131 | line = f.readline().decode().strip() 132 | if line != 'ply': 133 | raise Exception('Header error') 134 | n_verts = 0 135 | n_faces = 0 136 | vert_types = {} 137 | vert_bin_format = [] 138 | vert_bin_len = 0 139 | vert_bin_cols = 0 140 | line = f.readline().decode() 141 | parse_vertex_prop = False 142 | while line.strip() != 'end_header': 143 | if 'format' in line: 144 | if 'ascii' in line: 145 | binary = False 146 | elif 'binary_little_endian' in line: 147 | binary = True 148 | else: 149 | raise Exception('invalid ply format') 150 | if 'element face' in line: 151 | splits = line.strip().split(' ') 152 | n_faces = int(splits[-1]) 153 | parse_vertex_prop = False 154 | if 'element camera' in line: 155 | parse_vertex_prop = False 156 | if 'element vertex' in line: 157 | splits = line.strip().split(' ') 158 | n_verts = int(splits[-1]) 159 | parse_vertex_prop = True 160 | if parse_vertex_prop and 'property' in line: 161 | prop = line.strip().split() 162 | if prop[1] == 'float': 163 | vert_bin_format.append('f4') 164 | vert_bin_len += 4 165 | vert_bin_cols += 1 166 | elif prop[1] == 'uchar': 167 | vert_bin_format.append('B') 168 | vert_bin_len += 1 169 | vert_bin_cols += 1 170 | else: 171 | raise Exception('invalid property') 172 | vert_types[prop[2]] = len(vert_types) 173 | line = f.readline().decode() 174 | 175 | # parse content 176 | if binary: 177 | sz = n_verts * vert_bin_len 178 | fmt = ','.join(vert_bin_format) 179 | verts = np.ndarray(shape=(1, n_verts), dtype=np.dtype(fmt), buffer=f.read(sz)) 180 | verts = verts[0].astype(vert_bin_cols*'f4,').view(dtype='f4').reshape((n_verts,-1)) 181 | faces = [] 182 | for idx in range(n_faces): 183 | fmt = '= 2 and len(parts[1]) > 0: 223 | tidx = int(parts[1]) - 1 224 | else: 225 | tidx = -1 226 | if len(parts) >= 3 and len(parts[2]) > 0: 227 | nidx = int(parts[2]) - 1 228 | else: 229 | nidx = -1 230 | return vidx, tidx, nidx 231 | 232 | def read_obj(path): 233 | with open(path, 'r') as fp: 234 | lines = fp.readlines() 235 | 236 | verts = [] 237 | colors = [] 238 | fnorms = [] 239 | fnorm_map = collections.defaultdict(list) 240 | faces = [] 241 | for line in lines: 242 | line = line.strip() 243 | if line.startswith('#') or len(line) == 0: 244 | continue 245 | 246 | parts = line.split() 247 | if line.startswith('v '): 248 | parts = parts[1:] 249 | x,y,z = float(parts[0]), float(parts[1]), float(parts[2]) 250 | if len(parts) == 4 or len(parts) == 7: 
251 | w = float(parts[3]) 252 | x,y,z = x/w, y/w, z/w 253 | verts.append((x,y,z)) 254 | if len(parts) >= 6: 255 | r,g,b = float(parts[-3]), float(parts[-2]), float(parts[-1]) 256 | rgb.append((r,g,b)) 257 | 258 | elif line.startswith('vn '): 259 | parts = parts[1:] 260 | x,y,z = float(parts[0]), float(parts[1]), float(parts[2]) 261 | fnorms.append((x,y,z)) 262 | 263 | elif line.startswith('f '): 264 | parts = parts[1:] 265 | if len(parts) != 3: 266 | raise Exception('only triangle meshes supported atm') 267 | vidx0, tidx0, nidx0 = _read_obj_split_f(parts[0]) 268 | vidx1, tidx1, nidx1 = _read_obj_split_f(parts[1]) 269 | vidx2, tidx2, nidx2 = _read_obj_split_f(parts[2]) 270 | 271 | faces.append((vidx0, vidx1, vidx2)) 272 | if nidx0 >= 0: 273 | fnorm_map[vidx0].append( nidx0 ) 274 | if nidx1 >= 0: 275 | fnorm_map[vidx1].append( nidx1 ) 276 | if nidx2 >= 0: 277 | fnorm_map[vidx2].append( nidx2 ) 278 | 279 | verts = np.array(verts) 280 | colors = np.array(colors) 281 | fnorms = np.array(fnorms) 282 | faces = np.array(faces) 283 | 284 | # face normals to vertex normals 285 | norms = np.zeros_like(verts) 286 | for vidx in fnorm_map.keys(): 287 | ind = fnorm_map[vidx] 288 | norms[vidx] = fnorms[ind].sum(axis=0) 289 | N = np.linalg.norm(norms, axis=1, keepdims=True) 290 | np.divide(norms, N, out=norms, where=N != 0) 291 | 292 | return verts, faces, colors, norms 293 | -------------------------------------------------------------------------------- /co/metric.py: -------------------------------------------------------------------------------- 1 | # DepthInSpace is a PyTorch-based program which estimates 3D depth maps 2 | # from active structured-light sensor's multiple video frames. 3 | # 4 | # MIT License 5 | # 6 | # Copyright (c) 2019 autonomousvision 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to the following conditions: 14 | # 15 | # The above copyright notice and this permission notice shall be included in all 16 | # copies or substantial portions of the Software. 17 | # 18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | # SOFTWARE. 25 | 26 | import numpy as np 27 | from . 
import geometry 28 | 29 | def _process_inputs(estimate, target, mask): 30 | if estimate.shape != target.shape: 31 | raise Exception('estimate and target have to be same shape') 32 | if mask is None: 33 | mask = np.ones(estimate.shape, dtype=np.bool) 34 | else: 35 | mask = mask != 0 36 | if estimate.shape != mask.shape: 37 | raise Exception('estimate and mask have to be same shape') 38 | return estimate, target, mask 39 | 40 | def mse(estimate, target, mask=None): 41 | estimate, target, mask = _process_inputs(estimate, target, mask) 42 | m = np.sum((estimate[mask] - target[mask])**2) / mask.sum() 43 | return m 44 | 45 | def rmse(estimate, target, mask=None): 46 | return np.sqrt(mse(estimate, target, mask)) 47 | 48 | def mae(estimate, target, mask=None): 49 | estimate, target, mask = _process_inputs(estimate, target, mask) 50 | m = np.abs(estimate[mask] - target[mask]).sum() / mask.sum() 51 | return m 52 | 53 | def outlier_fraction(estimate, target, mask=None, threshold=0): 54 | estimate, target, mask = _process_inputs(estimate, target, mask) 55 | diff = np.abs(estimate[mask] - target[mask]) 56 | m = (diff > threshold).sum() / mask.sum() 57 | return m 58 | 59 | 60 | class Metric(object): 61 | def __init__(self, str_prefix=''): 62 | self.str_prefix = str_prefix 63 | self.reset() 64 | 65 | def reset(self): 66 | pass 67 | 68 | def add(self, es, ta, ma=None): 69 | pass 70 | 71 | def get(self): 72 | return {} 73 | 74 | def items(self): 75 | return self.get().items() 76 | 77 | def __str__(self): 78 | return ', '.join([f'{self.str_prefix}{key}={value:.5f}' for key, value in self.get().items()]) 79 | 80 | class MultipleMetric(Metric): 81 | def __init__(self, *metrics, **kwargs): 82 | self.metrics = [*metrics] 83 | super().__init__(**kwargs) 84 | 85 | def reset(self): 86 | for m in self.metrics: 87 | m.reset() 88 | 89 | def add(self, es, ta, ma=None): 90 | for m in self.metrics: 91 | m.add(es, ta, ma) 92 | 93 | def get(self): 94 | ret = {} 95 | for m in self.metrics: 96 | vals = m.get() 97 | for k in vals: 98 | ret[k] = vals[k] 99 | return ret 100 | 101 | def __str__(self): 102 | return '\n'.join([str(m) for m in self.metrics]) 103 | 104 | class BaseDistanceMetric(Metric): 105 | def __init__(self, name='', **kwargs): 106 | super().__init__(**kwargs) 107 | self.name = name 108 | 109 | def reset(self): 110 | self.dists = [] 111 | 112 | def add(self, es, ta, ma=None): 113 | pass 114 | 115 | def get(self): 116 | dists = np.hstack(self.dists) 117 | return { 118 | f'dist{self.name}_mean': float(np.mean(dists)), 119 | f'dist{self.name}_std': float(np.std(dists)), 120 | f'dist{self.name}_median': float(np.median(dists)), 121 | f'dist{self.name}_q10': float(np.percentile(dists, 10)), 122 | f'dist{self.name}_q90': float(np.percentile(dists, 90)), 123 | f'dist{self.name}_min': float(np.min(dists)), 124 | f'dist{self.name}_max': float(np.max(dists)), 125 | } 126 | 127 | class DistanceMetric(BaseDistanceMetric): 128 | def __init__(self, vec_length, p=2, **kwargs): 129 | super().__init__(name=f'{p}', **kwargs) 130 | self.vec_length = vec_length 131 | self.p = p 132 | 133 | def add(self, es, ta, ma=None): 134 | if es.shape != ta.shape or es.shape[1] != self.vec_length or es.ndim != 2: 135 | print(es.shape, ta.shape) 136 | raise Exception('es and ta have to be of shape Nxdim') 137 | if ma is not None: 138 | es = es[ma != 0] 139 | ta = ta[ma != 0] 140 | dist = np.linalg.norm(es - ta, ord=self.p, axis=1) 141 | self.dists.append( dist ) 142 | 143 | class OutlierFractionMetric(DistanceMetric): 144 | def __init__(self, 
thresholds, *args, **kwargs): 145 | super().__init__(*args, **kwargs) 146 | self.thresholds = thresholds 147 | 148 | def get(self): 149 | dists = np.hstack(self.dists) 150 | ret = {} 151 | for t in self.thresholds: 152 | ma = dists > t 153 | ret[f'of{t}'] = float(ma.sum() / ma.size) 154 | return ret 155 | 156 | class RelativeDistanceMetric(BaseDistanceMetric): 157 | def __init__(self, vec_length, p=2, **kwargs): 158 | super().__init__(name=f'rel{p}', **kwargs) 159 | self.vec_length = vec_length 160 | self.p = p 161 | 162 | def add(self, es, ta, ma=None): 163 | if es.shape != ta.shape or es.shape[1] != self.vec_length or es.ndim != 2: 164 | raise Exception('es and ta have to be of shape Nxdim') 165 | dist = np.linalg.norm(es - ta, ord=self.p, axis=1) 166 | denom = np.linalg.norm(ta, ord=self.p, axis=1) 167 | dist /= denom 168 | if ma is not None: 169 | dist = dist[ma != 0] 170 | self.dists.append( dist ) 171 | 172 | class RotmDistanceMetric(BaseDistanceMetric): 173 | def __init__(self, type='identity', **kwargs): 174 | super().__init__(name=type, **kwargs) 175 | self.type = type 176 | 177 | def add(self, es, ta, ma=None): 178 | if es.shape != ta.shape or es.shape[1] != 3 or es.shape[2] != 3 or es.ndim != 3: 179 | print(es.shape, ta.shape) 180 | raise Exception('es and ta have to be of shape Nx3x3') 181 | if ma is not None: 182 | raise Exception('mask is not implemented') 183 | if self.type == 'identity': 184 | self.dists.append( geometry.rotm_distance_identity(es, ta) ) 185 | elif self.type == 'geodesic': 186 | self.dists.append( geometry.rotm_distance_geodesic_unit_sphere(es, ta) ) 187 | else: 188 | raise Exception('invalid distance type') 189 | 190 | class QuaternionDistanceMetric(BaseDistanceMetric): 191 | def __init__(self, type='angle', **kwargs): 192 | super().__init__(name=type, **kwargs) 193 | self.type = type 194 | 195 | def add(self, es, ta, ma=None): 196 | if es.shape != ta.shape or es.shape[1] != 4 or es.ndim != 2: 197 | print(es.shape, ta.shape) 198 | raise Exception('es and ta have to be of shape Nx4') 199 | if ma is not None: 200 | raise Exception('mask is not implemented') 201 | if self.type == 'angle': 202 | self.dists.append( geometry.quat_distance_angle(es, ta) ) 203 | elif self.type == 'mineucl': 204 | self.dists.append( geometry.quat_distance_mineucl(es, ta) ) 205 | elif self.type == 'normdiff': 206 | self.dists.append( geometry.quat_distance_normdiff(es, ta) ) 207 | else: 208 | raise Exception('invalid distance type') 209 | 210 | 211 | class BinaryAccuracyMetric(Metric): 212 | def __init__(self, thresholds=np.linspace(0.0, 1.0, num=101, dtype=np.float64)[:-1], **kwargs): 213 | self.thresholds = thresholds 214 | super().__init__(**kwargs) 215 | 216 | def reset(self): 217 | self.tps = [0 for wp in self.thresholds] 218 | self.fps = [0 for wp in self.thresholds] 219 | self.fns = [0 for wp in self.thresholds] 220 | self.tns = [0 for wp in self.thresholds] 221 | self.n_pos = 0 222 | self.n_neg = 0 223 | 224 | def add(self, es, ta, ma=None): 225 | if ma is not None: 226 | raise Exception('mask is not implemented') 227 | es = es.ravel() 228 | ta = ta.ravel() 229 | if es.shape[0] != ta.shape[0]: 230 | raise Exception('invalid shape of es, or ta') 231 | if es.min() < 0 or es.max() > 1: 232 | raise Exception('estimate has wrong value range') 233 | ta_p = (ta == 1) 234 | ta_n = (ta == 0) 235 | es_p = es[ta_p] 236 | es_n = es[ta_n] 237 | for idx, wp in enumerate(self.thresholds): 238 | wp = np.asscalar(wp) 239 | self.tps[idx] += (es_p > wp).sum() 240 | self.fps[idx] += (es_n > 
wp).sum() 241 | self.fns[idx] += (es_p <= wp).sum() 242 | self.tns[idx] += (es_n <= wp).sum() 243 | self.n_pos += ta_p.sum() 244 | self.n_neg += ta_n.sum() 245 | 246 | def get(self): 247 | tps = np.array(self.tps).astype(np.float32) 248 | fps = np.array(self.fps).astype(np.float32) 249 | fns = np.array(self.fns).astype(np.float32) 250 | tns = np.array(self.tns).astype(np.float32) 251 | wp = self.thresholds 252 | 253 | ret = {} 254 | 255 | precisions = np.divide(tps, tps + fps, out=np.zeros_like(tps), where=tps + fps != 0) 256 | recalls = np.divide(tps, tps + fns, out=np.zeros_like(tps), where=tps + fns != 0) # tprs 257 | fprs = np.divide(fps, fps + tns, out=np.zeros_like(tps), where=fps + tns != 0) 258 | 259 | precisions = np.r_[0, precisions, 1] 260 | recalls = np.r_[1, recalls, 0] 261 | fprs = np.r_[1, fprs, 0] 262 | 263 | ret['auc'] = float(-np.trapz(recalls, fprs)) 264 | ret['prauc'] = float(-np.trapz(precisions, recalls)) 265 | ret['ap'] = float(-(np.diff(recalls) * precisions[:-1]).sum()) 266 | 267 | accuracies = np.divide(tps + tns, tps + tns + fps + fns) 268 | aacc = np.mean(accuracies) 269 | for t in np.linspace(0,1,num=11)[1:-1]: 270 | idx = np.argmin(np.abs(t - wp)) 271 | ret[f'acc{wp[idx]:.2f}'] = float(accuracies[idx]) 272 | 273 | return ret 274 | -------------------------------------------------------------------------------- /co/utils.py: -------------------------------------------------------------------------------- 1 | # DepthInSpace is a PyTorch-based program which estimates 3D depth maps 2 | # from active structured-light sensor's multiple video frames. 3 | # 4 | # MIT License 5 | # 6 | # Copyright (c) 2019 autonomousvision 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to the following conditions: 14 | # 15 | # The above copyright notice and this permission notice shall be included in all 16 | # copies or substantial portions of the Software. 17 | # 18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | # SOFTWARE. 
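# co/utils.py (below) collects small training utilities: str2bool for parsing
# boolean argparse flags, StopWatch for accumulating named timings, ETA for
# estimating the remaining time of a loop, and git_hash for recording the
# current repository revision. A minimal usage sketch, assuming a loop of
# num_batches iterations (num_batches is a placeholder, not defined in this file):
#
#   watch = StopWatch()
#   eta = ETA(length=num_batches)
#   for i in range(num_batches):
#       watch.start('batch')
#       # ... one training step ...
#       watch.stop('batch')
#       eta.update(i)
#       print(watch, '| remaining:', eta.get_remaining_time_str())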
25 | 26 | import numpy as np 27 | import time 28 | from collections import OrderedDict 29 | import argparse 30 | import subprocess 31 | 32 | def str2bool(v): 33 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 34 | return True 35 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 36 | return False 37 | else: 38 | raise argparse.ArgumentTypeError('Boolean value expected.') 39 | 40 | class StopWatch(object): 41 | def __init__(self): 42 | self.timings = OrderedDict() 43 | self.starts = {} 44 | 45 | def start(self, name): 46 | self.starts[name] = time.time() 47 | 48 | def stop(self, name): 49 | if name not in self.timings: 50 | self.timings[name] = [] 51 | self.timings[name].append(time.time() - self.starts[name]) 52 | 53 | def get(self, name=None, reduce=np.sum): 54 | if name is not None: 55 | return reduce(self.timings[name]) 56 | else: 57 | ret = {} 58 | for k in self.timings: 59 | ret[k] = reduce(self.timings[k]) 60 | return ret 61 | 62 | def __repr__(self): 63 | return ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get().items()]) 64 | def __str__(self): 65 | return ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get().items()]) 66 | 67 | class ETA(object): 68 | def __init__(self, length): 69 | self.length = length 70 | self.start_time = time.time() 71 | self.current_idx = 0 72 | self.current_time = time.time() 73 | 74 | def update(self, idx): 75 | self.current_idx = idx 76 | self.current_time = time.time() 77 | 78 | def get_elapsed_time(self): 79 | return self.current_time - self.start_time 80 | 81 | def get_item_time(self): 82 | return self.get_elapsed_time() / (self.current_idx + 1) 83 | 84 | def get_remaining_time(self): 85 | return self.get_item_time() * (self.length - self.current_idx + 1) 86 | 87 | def format_time(self, seconds): 88 | minutes, seconds = divmod(seconds, 60) 89 | hours, minutes = divmod(minutes, 60) 90 | hours = int(hours) 91 | minutes = int(minutes) 92 | return f'{hours:02d}:{minutes:02d}:{seconds:05.2f}' 93 | 94 | def get_elapsed_time_str(self): 95 | return self.format_time(self.get_elapsed_time()) 96 | 97 | def get_remaining_time_str(self): 98 | return self.format_time(self.get_remaining_time()) 99 | 100 | def git_hash(cwd=None): 101 | ret = subprocess.run(['git', 'describe', '--always'], cwd=cwd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) 102 | hash = ret.stdout 103 | if hash is not None and 'fatal' not in hash.decode(): 104 | return hash.decode().strip() 105 | else: 106 | return None 107 | 108 | -------------------------------------------------------------------------------- /config.json: -------------------------------------------------------------------------------- 1 | { 2 | "OUTPUT_DIR": "/path/to/output/dir", 3 | "DATA_DIR": "/path/to/data/dir", 4 | "SHAPENET_DIR": "/path/to/ShapeNetCore.v2", 5 | "CTD_DIR": "/path/to/connecting_the_dots", 6 | "LITEFLOWNET_DIR": "/path/to/pytorch-liteflownet" 7 | } -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | # DepthInSpace is a PyTorch-based program which estimates 3D depth maps 2 | # from active structured-light sensor's multiple video frames. 
3 | # 4 | # Copyright (c) 2021 ams International AG 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 7 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation 8 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 9 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 10 | # 11 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES 14 | # OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE 15 | # LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR 16 | # IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /data/base_dataset.py: -------------------------------------------------------------------------------- 1 | # DepthInSpace is a PyTorch-based program which estimates 3D depth maps 2 | # from active structured-light sensor's multiple video frames. 3 | # 4 | # MIT License 5 | # 6 | # Copyright (c) 2019 autonomousvision 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to the following conditions: 14 | # 15 | # The above copyright notice and this permission notice shall be included in all 16 | # copies or substantial portions of the Software. 17 | # 18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | # SOFTWARE. 
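# data/base_dataset.py (below) defines the dataset plumbing shared by the
# training code: TestSet/TestSets wrap named evaluation splits, MultiDataset
# concatenates several datasets behind a single index space, and BaseDataset
# hands each sample a numpy RandomState whose seed depends on the current
# epoch (or stays fixed across epochs when fix_seed_per_epoch=True), making
# augmentations reproducible. A rough sketch of the seeding behaviour,
# assuming a hypothetical subclass MyDataset that calls get_rng inside
# __getitem__:
#
#   dset = MyDataset(train=True, fix_seed_per_epoch=False)
#   dset.current_epoch = 3
#   rng = dset.get_rng(idx=7)   # seeded with (3 + 1) * len(dset) + 7
#   # with train=False the seed is simply idx, so evaluation stays deterministic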
25 | 26 | import torch.utils.data 27 | import numpy as np 28 | 29 | class TestSet(object): 30 | def __init__(self, name, dset, test_frequency=1): 31 | self.name = name 32 | self.dset = dset 33 | self.test_frequency = test_frequency 34 | 35 | class TestSets(list): 36 | def append(self, name, dset, test_frequency=1): 37 | super().append(TestSet(name, dset, test_frequency)) 38 | 39 | 40 | 41 | class MultiDataset(torch.utils.data.Dataset): 42 | def __init__(self, *datasets): 43 | self.current_epoch = 0 44 | 45 | self.datasets = [] 46 | self.cum_n_samples = [0] 47 | 48 | for dataset in datasets: 49 | self.append(dataset) 50 | 51 | def append(self, dataset): 52 | self.datasets.append(dataset) 53 | self.__update_cum_n_samples(dataset) 54 | 55 | def __update_cum_n_samples(self, dataset): 56 | n_samples = self.cum_n_samples[-1] + len(dataset) 57 | self.cum_n_samples.append(n_samples) 58 | 59 | def dataset_updated(self): 60 | self.cum_n_samples = [0] 61 | for dset in self.datasets: 62 | self.__update_cum_n_samples(dset) 63 | 64 | def __len__(self): 65 | return self.cum_n_samples[-1] 66 | 67 | def __getitem__(self, idx): 68 | didx = np.searchsorted(self.cum_n_samples, idx, side='right') - 1 69 | sidx = idx - self.cum_n_samples[didx] 70 | return self.datasets[didx][sidx] 71 | 72 | 73 | 74 | class BaseDataset(torch.utils.data.Dataset): 75 | def __init__(self, train=True, fix_seed_per_epoch=False): 76 | self.current_epoch = 0 77 | self.train = train 78 | self.fix_seed_per_epoch = fix_seed_per_epoch 79 | 80 | def get_rng(self, idx): 81 | rng = np.random.RandomState() 82 | if self.train: 83 | if self.fix_seed_per_epoch: 84 | seed = 1 * len(self) + idx 85 | else: 86 | seed = (self.current_epoch + 1) * len(self) + idx 87 | rng.seed(seed) 88 | else: 89 | rng.seed(idx) 90 | return rng 91 | -------------------------------------------------------------------------------- /data/create_syn_data.py: -------------------------------------------------------------------------------- 1 | # DepthInSpace is a PyTorch-based program which estimates 3D depth maps 2 | # from active structured-light sensor's multiple video frames. 3 | # 4 | # 5 | # MIT License 6 | # 7 | # Copyright (c) 2019 autonomousvision 8 | # 9 | # Permission is hereby granted, free of charge, to any person obtaining a copy 10 | # of this software and associated documentation files (the "Software"), to deal 11 | # in the Software without restriction, including without limitation the rights 12 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 13 | # copies of the Software, and to permit persons to whom the Software is 14 | # furnished to do so, subject to the following conditions: 15 | # 16 | # The above copyright notice and this permission notice shall be included in all 17 | # copies or substantial portions of the Software. 18 | # 19 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 20 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 21 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 22 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 23 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 24 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 25 | # SOFTWARE. 
26 | # 27 | # 28 | # MIT License 29 | # 30 | # Copyright, 2021 ams International AG 31 | # 32 | # Permission is hereby granted, free of charge, to any person obtaining a copy 33 | # of this software and associated documentation files (the "Software"), to deal 34 | # in the Software without restriction, including without limitation the rights 35 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 36 | # copies of the Software, and to permit persons to whom the Software is 37 | # furnished to do so, subject to the following conditions: 38 | # 39 | # The above copyright notice and this permission notice shall be included in all 40 | # copies or substantial portions of the Software. 41 | # 42 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 43 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 44 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 45 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 46 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 47 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 48 | # SOFTWARE. 49 | 50 | import argparse 51 | 52 | import numpy as np 53 | import pickle 54 | from pathlib import Path 55 | import time 56 | import json 57 | import cv2 58 | import collections 59 | import sys 60 | import h5py 61 | import os 62 | 63 | sys.path.append('../') 64 | import co 65 | from data_manipulation import get_rotation_matrix, read_pattern_file, post_process 66 | 67 | config_path = os.path.abspath(os.path.join(os.path.dirname( __file__ ), '..', 'config.json')) 68 | with open(config_path) as fp: 69 | config = json.load(fp) 70 | CTD_path = Path(config['CTD_DIR']) 71 | sys.path.append(str(CTD_path / 'data')) 72 | sys.path.append(str(CTD_path / 'renderer')) 73 | 74 | from lcn import lcn 75 | from cyrender import PyRenderInput, PyCamera, PyShader, PyRenderer 76 | 77 | def get_objs(shapenet_dir, obj_classes, num_perclass=100): 78 | shapenet = {'chair': '03001627', 79 | 'airplane': '02691156', 80 | 'car': '02958343', 81 | 'watercraft': '04530566'} 82 | 83 | obj_paths = [] 84 | for cls in obj_classes: 85 | if cls not in shapenet.keys(): 86 | raise Exception('unknown class name') 87 | ids = shapenet[cls] 88 | obj_path = sorted(Path(f'{shapenet_dir}/{ids}').glob('**/models/*.obj')) 89 | obj_paths += obj_path[:num_perclass] 90 | print(f'found {len(obj_paths)} object paths') 91 | 92 | objs = [] 93 | for obj_path in obj_paths: 94 | print(f'load {obj_path}') 95 | v, f, _, n = co.io3d.read_obj(obj_path) 96 | diffs = v.max(axis=0) - v.min(axis=0) 97 | v /= (0.5 * diffs.max()) 98 | v -= (v.min(axis=0) + 1) 99 | f = f.astype(np.int32) 100 | objs.append((v, f, n)) 101 | print(f'loaded {len(objs)} objects') 102 | 103 | return objs 104 | 105 | 106 | def get_mesh(rng, min_z=0): 107 | # set up background board 108 | verts, faces, normals, colors = [], [], [], [] 109 | v, f, n = co.geometry.xyplane(z=0, interleaved=True) 110 | v[:, 2] += -v[:, 2].min() + rng.uniform(3, 5) 111 | v[:, :2] *= 5e2 112 | v[:, 2] = np.mean(v[:, 2]) + (v[:, 2] - np.mean(v[:, 2])) * 5e2 113 | c = np.empty_like(v) 114 | c[:] = rng.uniform(0, 1, size=(3,)).astype(np.float32) 115 | verts.append(v) 116 | faces.append(f) 117 | normals.append(n) 118 | colors.append(c) 119 | 120 | # randomly sample 4 foreground objects for each scene 121 | for shape_idx in range(4): 122 | v, f, n = objs[rng.randint(0, len(objs))] 123 | v, f, n = 
v.copy(), f.copy(), n.copy() 124 | 125 | s = rng.uniform(0.25, 1) 126 | v *= s 127 | R = co.geometry.rotm_from_quat(co.geometry.quat_random(rng=rng)) 128 | v = v @ R.T 129 | n = n @ R.T 130 | v[:, 2] += -v[:, 2].min() + min_z + rng.uniform(0.5, 3) 131 | v[:, :2] += rng.uniform(-1, 1, size=(1, 2)) 132 | 133 | c = np.empty_like(v) 134 | c[:] = rng.uniform(0, 1, size=(3,)).astype(np.float32) 135 | 136 | verts.append(v.astype(np.float32)) 137 | faces.append(f) 138 | normals.append(n) 139 | colors.append(c) 140 | 141 | verts, faces = co.geometry.stack_mesh(verts, faces) 142 | normals = np.vstack(normals).astype(np.float32) 143 | colors = np.vstack(colors).astype(np.float32) 144 | return verts, faces, colors, normals 145 | 146 | 147 | def create_data(pattern_type, out_root, idx, n_samples, imsize_proj, imsize, pattern, K_proj, K, K_processed, baseline, blend_im, noise, 148 | track_length=4): 149 | tic = time.time() 150 | rng = np.random.RandomState() 151 | 152 | rng.seed(idx) 153 | 154 | verts, faces, colors, normals = get_mesh(rng) 155 | data = PyRenderInput(verts=verts.copy(), colors=colors.copy(), normals=normals.copy(), faces=faces.copy()) 156 | print(f'loading mesh for sample {idx + 1}/{n_samples} took {time.time() - tic}[s]') 157 | 158 | # let the camera point to the center 159 | center = np.array([0, 0, 3], dtype=np.float32) 160 | 161 | basevec = np.array([-baseline, 0, 0], dtype=np.float32) 162 | unit = np.array([0, 0, 1], dtype=np.float32) 163 | 164 | cam_x_ = rng.uniform(-0.2, 0.2) 165 | cam_y_ = rng.uniform(-0.2, 0.2) 166 | cam_z_ = rng.uniform(-0.2, 0.2) 167 | 168 | ret = collections.defaultdict(list) 169 | blend_im_rnd = np.clip(blend_im + rng.uniform(-0.1, 0.1), 0, 1) 170 | 171 | # capture the same static scene from different view points as a track 172 | for ind in range(track_length): 173 | 174 | cam_x = cam_x_ + rng.uniform(-0.1, 0.1) 175 | cam_y = cam_y_ + rng.uniform(-0.1, 0.1) 176 | cam_z = cam_z_ + rng.uniform(-0.1, 0.1) 177 | 178 | tcam = np.array([cam_x, cam_y, cam_z], dtype=np.float32) 179 | 180 | if np.linalg.norm(tcam[0:2]) < 1e-9: 181 | Rcam = np.eye(3, dtype=np.float32) 182 | else: 183 | Rcam = get_rotation_matrix(center, center - tcam) 184 | 185 | tproj = tcam + basevec 186 | Rproj = Rcam 187 | 188 | ret['R'].append(Rcam) 189 | ret['t'].append(tcam) 190 | 191 | fx_proj = K_proj[0, 0] 192 | fy_proj = K_proj[1, 1] 193 | px_proj = K_proj[0, 2] 194 | py_proj = K_proj[1, 2] 195 | im_height_proj = imsize_proj[0] 196 | im_width_proj = imsize_proj[1] 197 | proj = PyCamera(fx_proj, fy_proj, px_proj, py_proj, Rproj, tproj, im_width_proj, im_height_proj) 198 | 199 | fx = K[0, 0] 200 | fy = K[1, 1] 201 | px = K[0, 2] 202 | py = K[1, 2] 203 | im_height = imsize[0] 204 | im_width = imsize[1] 205 | cam = PyCamera(fx, fy, px, py, Rcam, tcam, im_width, im_height) 206 | 207 | shader = PyShader(0.5, 1.5, 0.0, 10) 208 | pyrenderer = PyRenderer(cam, shader, engine='gpu') 209 | if args.pattern_type == 'default': 210 | pyrenderer.mesh_proj(data, proj, pattern.copy(), d_alpha=0, d_beta=0.0) 211 | else: 212 | pyrenderer.mesh_proj(data, proj, pattern.copy(), d_alpha=0, d_beta=0.35) 213 | 214 | # get the reflected laser pattern $R$ 215 | im = pyrenderer.color().copy() 216 | im = cv2.cvtColor(im, cv2.COLOR_RGB2GRAY) 217 | 218 | focal_length = K_processed[0, 0] 219 | depth = pyrenderer.depth().copy() 220 | disp = baseline * focal_length / depth 221 | 222 | # get the ambient image $A$ 223 | ambient = pyrenderer.normal().copy() 224 | ambient = cv2.cvtColor(ambient, cv2.COLOR_RGB2GRAY) 225 | 226 | # 
get the noise free IR image $J$ 227 | im = blend_im_rnd * im + (1 - blend_im_rnd) * ambient 228 | ret['ambient'].append(post_process(pattern_type, ambient)[None].astype(np.float32)) 229 | 230 | # get the gradient magnitude of the ambient image $|\nabla A|$ 231 | ambient = ambient.astype(np.float32) 232 | sobelx = cv2.Sobel(ambient, cv2.CV_32F, 1, 0, ksize=5) 233 | sobely = cv2.Sobel(ambient, cv2.CV_32F, 0, 1, ksize=5) 234 | grad = np.sqrt(sobelx ** 2 + sobely ** 2) 235 | grad = np.maximum(grad - 0.8, 0.0) # parameter 236 | 237 | # get the local contract normalized grad LCN($|\nabla A|$) 238 | grad_lcn, grad_std = lcn.normalize(grad, 5, 0.1) 239 | grad_lcn = np.clip(grad_lcn, 0.0, 1.0) # parameter 240 | ret['grad'].append(post_process(pattern_type, grad_lcn)[None].astype(np.float32)) 241 | 242 | ret['im'].append(post_process(pattern_type, im)[None].astype(np.float32)) 243 | ret['disp'].append(post_process(pattern_type, disp)[None].astype(np.float32)) 244 | 245 | for key in ret.keys(): 246 | ret[key] = np.stack(ret[key], axis=0) 247 | 248 | # save to files 249 | out_dir = out_root / f'{idx:08d}' 250 | out_dir.mkdir(exist_ok=True, parents=True) 251 | out_path = out_dir / f'frames.hdf5' 252 | with h5py.File(out_path, "w") as f: 253 | for k, val in ret.items(): 254 | # f.create_dataset(k, data=val, compression="lzf") 255 | f.create_dataset(k, data=val) 256 | 257 | print(f'create sample {idx + 1}/{n_samples} took {time.time() - tic}[s]') 258 | 259 | 260 | if __name__ == '__main__': 261 | 262 | parser = argparse.ArgumentParser() 263 | parser.add_argument('pattern_type', 264 | help='Select the pattern file for projecting dots', 265 | default='default', 266 | choices=['default', 'kinect', 'real'], type=str) 267 | args = parser.parse_args() 268 | 269 | np.random.seed(42) 270 | 271 | # output directory 272 | config_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'config.json')) 273 | with open(config_path) as fp: 274 | config = json.load(fp) 275 | data_root = Path(config['DATA_DIR']) 276 | shapenet_root = config['SHAPENET_DIR'] 277 | 278 | out_root = data_root 279 | out_root.mkdir(parents=True, exist_ok=True) 280 | 281 | # load shapenet models 282 | obj_classes = ['chair'] 283 | objs = get_objs(shapenet_root, obj_classes) 284 | 285 | # camera parameters 286 | if args.pattern_type == 'real': 287 | fl_proj = 1112.1806640625 288 | fl = 1112.1806640625 289 | imsize_proj = (1280, 1080) 290 | imsize = (1280, 1080) 291 | imsize_processed = (512, 432) 292 | K_proj = np.array([[fl_proj, 0, 517.0896606445312], [0, fl_proj, 649.6329956054688], [0, 0, 1]], dtype=np.float32) 293 | K = np.array([[fl, 0, 517.0896606445312], [0, fl, 649.6329956054688], [0, 0, 1]], dtype=np.float32) 294 | baseline = 0.0246 295 | blend_im = 0.6 296 | noise = 0 297 | else: 298 | fl_proj = 1582.06005876 299 | fl = 435.2 300 | imsize_proj = (4096, 4096) 301 | imsize = (512, 432) 302 | imsize_processed = (512, 432) 303 | K_proj = np.array([[fl_proj, 0, 2047.5], [0, fl_proj, 2047.5], [0, 0, 1]], dtype=np.float32) 304 | K = np.array([[fl, 0, 216], [0, fl, 256], [0, 0, 1]], dtype=np.float32) 305 | baseline = 0.025 306 | blend_im = 0.6 307 | noise = 0 308 | 309 | # capture the same static scene from different view points as a track 310 | track_length = 4 311 | 312 | # load pattern image 313 | pattern = read_pattern_file(args.pattern_type, imsize_proj) 314 | 315 | x_cam = np.arange(0, imsize[1]) 316 | y_cam = np.arange(0, imsize[0]) 317 | x_mesh, y_mesh = np.meshgrid(x_cam, y_cam) 318 | x_mesh_f = np.reshape(x_mesh, [-1]) 
319 | y_mesh_f = np.reshape(y_mesh, [-1]) 320 | 321 | grid_points = np.stack([x_mesh_f, y_mesh_f, np.ones_like(x_mesh_f)], axis=0) 322 | grid_points_mapped = K_proj.dot(np.linalg.inv(K).dot(grid_points)) 323 | grid_points_mapped = grid_points_mapped / grid_points_mapped[2, :] 324 | 325 | x_map = np.reshape(grid_points_mapped[0, :], x_mesh.shape) 326 | y_map = np.reshape(grid_points_mapped[1, :], y_mesh.shape) 327 | x_map, y_map = x_map.astype('float32'), y_map.astype('float32') 328 | mapped_pattern = cv2.remap(pattern, x_map, y_map, cv2.INTER_LINEAR) 329 | 330 | pattern_processed, K_processed = post_process(args.pattern_type, mapped_pattern, K) 331 | # write settings to file 332 | settings = { 333 | 'imsize': imsize_processed, 334 | 'pattern': pattern_processed, 335 | 'baseline': baseline, 336 | 'K': K_processed, 337 | } 338 | out_path = out_root / f'settings.pkl' 339 | print(f'write settings to {out_path}') 340 | with open(str(out_path), 'wb') as f: 341 | pickle.dump(settings, f, pickle.HIGHEST_PROTOCOL) 342 | 343 | # start the job 344 | n_samples = 2 ** 10 + 2 ** 13 345 | # n_samples = 2048 346 | for idx in range(n_samples): 347 | parameters = ( 348 | args.pattern_type, out_root, idx, n_samples, imsize_proj, imsize, pattern, K_proj, K, K_processed, baseline, blend_im, noise, track_length) 349 | create_data(*parameters) -------------------------------------------------------------------------------- /data/data_manipulation.py: -------------------------------------------------------------------------------- 1 | # DepthInSpace is a PyTorch-based program which estimates 3D depth maps 2 | # from active structured-light sensor's multiple video frames. 3 | # 4 | # 5 | # MIT License 6 | # 7 | # Copyright (c) 2019 autonomousvision 8 | # 9 | # Permission is hereby granted, free of charge, to any person obtaining a copy 10 | # of this software and associated documentation files (the "Software"), to deal 11 | # in the Software without restriction, including without limitation the rights 12 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 13 | # copies of the Software, and to permit persons to whom the Software is 14 | # furnished to do so, subject to the following conditions: 15 | # 16 | # The above copyright notice and this permission notice shall be included in all 17 | # copies or substantial portions of the Software. 18 | # 19 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 20 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 21 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 22 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 23 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 24 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 25 | # SOFTWARE. 
26 | # 27 | # 28 | # MIT License 29 | # 30 | # Copyright, 2021 ams International AG 31 | # 32 | # Permission is hereby granted, free of charge, to any person obtaining a copy 33 | # of this software and associated documentation files (the "Software"), to deal 34 | # in the Software without restriction, including without limitation the rights 35 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 36 | # copies of the Software, and to permit persons to whom the Software is 37 | # furnished to do so, subject to the following conditions: 38 | # 39 | # The above copyright notice and this permission notice shall be included in all 40 | # copies or substantial portions of the Software. 41 | # 42 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 43 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 44 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 45 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 46 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 47 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 48 | # SOFTWARE. 49 | 50 | import numpy as np 51 | import cv2 52 | 53 | def read_pattern_file(pattern_type, pattern_size): 54 | if pattern_type == 'default': 55 | pattern_path = 'default_pattern.png' 56 | elif pattern_type == 'kinect': 57 | pattern_path = 'kinect_pattern.png' 58 | elif pattern_type == 'real': 59 | pattern_path = 'real_pattern.png' 60 | 61 | pattern = cv2.imread(pattern_path) 62 | pattern = pattern.astype(np.float32) 63 | pattern /= 255 64 | 65 | if pattern.ndim == 2: 66 | pattern = np.stack([pattern for idx in range(3)], axis=2) 67 | 68 | if pattern_type == 'default': 69 | pattern = np.rot90(np.flip(pattern, axis=1)) 70 | elif pattern_type == 'kinect': 71 | min_dim = min(pattern.shape[0:2]) 72 | start_h = (pattern.shape[0] - min_dim) // 2 73 | start_w = (pattern.shape[1] - min_dim) // 2 74 | pattern = pattern[start_h:start_h + min_dim, start_w:start_w + min_dim] 75 | pattern = cv2.resize(pattern, pattern_size, interpolation=cv2.INTER_LINEAR) 76 | 77 | return pattern 78 | 79 | def get_rotation_matrix(v0, v1): 80 | v0 = v0/np.linalg.norm(v0) 81 | v1 = v1/np.linalg.norm(v1) 82 | v = np.cross(v0,v1) 83 | c = np.dot(v0,v1) 84 | s = np.linalg.norm(v) 85 | I = np.eye(3) 86 | vXStr = '{} {} {}; {} {} {}; {} {} {}'.format(0, -v[2], v[1], v[2], 0, -v[0], -v[1], v[0], 0) 87 | k = np.matrix(vXStr) 88 | r = I + k + k @ k * ((1 -c)/(s**2)) 89 | return np.asarray(r.astype(np.float32)) 90 | 91 | def post_process(pattern_type, im, K = None): 92 | if pattern_type == 'real': 93 | im_processed = im[128:-128, 108:-108, ...].copy() 94 | im_processed = cv2.resize(im_processed, (432, 512), interpolation=cv2.INTER_LINEAR) 95 | 96 | if K is not None: 97 | K_processed = K.copy() 98 | 99 | K_processed[0, 0] = K_processed[0, 0] / 2 100 | K_processed[1, 1] = K_processed[1, 1] / 2 101 | 102 | K_processed[0, 2] = (K_processed[0, 2] - 108) / 2 103 | K_processed[1, 2] = (K_processed[1, 2] - 128) / 2 104 | 105 | return im_processed, K_processed 106 | else: 107 | return im_processed 108 | else: 109 | if K is not None: 110 | return im, K 111 | else: 112 | return im 113 | 114 | def augment_image(img,rng, amb=None, disp=None, primary_disp= None, sgm_disp= None, grad=None,max_shift=64,max_blur=1.5,max_noise=10.0,max_sp_noise=0.001): 115 | 116 | # get min/max values of image 117 | min_val = np.min(img) 118 | max_val = 
np.max(img) 119 | 120 | # init augmented image 121 | img_aug = img 122 | amb_aug = amb 123 | 124 | # init disparity correction map 125 | disp_aug = disp 126 | primary_disp_aug = primary_disp 127 | sgm_disp_aug = sgm_disp 128 | grad_aug = grad 129 | 130 | # apply affine transformation 131 | if max_shift>1: 132 | 133 | # affine parameters 134 | rows,cols = img.shape 135 | shear = 0 136 | shift = 0 137 | shear_correction = 0 138 | if rng.uniform(0,1)<0.75: shear = rng.uniform(-max_shift,max_shift) # shear with 75% probability 139 | else: shift = rng.uniform(-max_shift/2,max_shift) # shift with 25% probability 140 | if shear<0: shear_correction = -shear 141 | 142 | # affine transformation 143 | a = shear/float(rows) 144 | b = shift+shear_correction 145 | 146 | # warp image 147 | T = np.float32([[1,a,b],[0,1,0]]) 148 | img_aug = cv2.warpAffine(img_aug,T,(cols,rows)) 149 | if amb is not None: 150 | amb_aug = cv2.warpAffine(amb_aug,T,(cols,rows)) 151 | if grad is not None: 152 | grad_aug = cv2.warpAffine(grad,T,(cols,rows)) 153 | 154 | # disparity correction map 155 | col = a*np.array(range(rows))+b 156 | disp_delta = np.tile(col,(cols,1)).transpose() 157 | if disp is not None: 158 | disp_aug = cv2.warpAffine(disp+disp_delta,T,(cols,rows)) 159 | if primary_disp is not None: 160 | primary_disp_aug = cv2.warpAffine(primary_disp+disp_delta,T,(cols,rows)) 161 | if sgm_disp is not None: 162 | sgm_disp_aug = cv2.warpAffine(sgm_disp+disp_delta,T,(cols,rows)) 163 | 164 | # gaussian smoothing 165 | if rng.uniform(0,1)<0.5: 166 | img_aug = cv2.GaussianBlur(img_aug,(5,5),rng.uniform(0.2,max_blur)) 167 | if amb is not None: 168 | amb_aug = cv2.GaussianBlur(amb_aug,(5,5),rng.uniform(0.2,max_blur)) 169 | 170 | # per-pixel gaussian noise 171 | img_aug = img_aug + rng.randn(*img_aug.shape)*rng.uniform(0.0,max_noise)/255.0 172 | if amb is not None: 173 | amb_aug = amb_aug + rng.randn(*amb_aug.shape)*rng.uniform(0.0,max_noise)/255.0 174 | 175 | # salt-and-pepper noise 176 | if rng.uniform(0,1)<0.5: 177 | ratio=rng.uniform(0.0,max_sp_noise) 178 | img_shape = img_aug.shape 179 | img_aug = img_aug.flatten() 180 | coord = rng.choice(np.size(img_aug), int(np.size(img_aug)*ratio)) 181 | img_aug[coord] = max_val 182 | coord = rng.choice(np.size(img_aug), int(np.size(img_aug)*ratio)) 183 | img_aug[coord] = min_val 184 | img_aug = np.reshape(img_aug, img_shape) 185 | 186 | # clip intensities back to [0,1] 187 | img_aug = np.maximum(img_aug,0.0) 188 | img_aug = np.minimum(img_aug,1.0) 189 | 190 | if amb is not None: 191 | amb_aug = np.maximum(amb_aug,0.0) 192 | amb_aug = np.minimum(amb_aug,1.0) 193 | 194 | # return image 195 | return img_aug, amb_aug, disp_aug, primary_disp_aug, sgm_disp_aug, grad_aug 196 | -------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | # DepthInSpace is a PyTorch-based program which estimates 3D depth maps 2 | # from active structured-light sensor's multiple video frames. 
3 | # 4 | # MIT License 5 | # 6 | # Copyright (c) 2019 autonomousvision 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to the following conditions: 14 | # 15 | # The above copyright notice and this permission notice shall be included in all 16 | # copies or substantial portions of the Software. 17 | # 18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | # SOFTWARE. 25 | 26 | import h5py 27 | import numpy as np 28 | import pickle 29 | import cv2 30 | import os 31 | 32 | from . import base_dataset 33 | from .data_manipulation import augment_image 34 | 35 | 36 | class TrackSynDataset(base_dataset.BaseDataset): 37 | ''' 38 | Load locally saved synthetic dataset 39 | ''' 40 | def __init__(self, settings_path, sample_paths, track_length=2, train=True, data_aug=False, load_flow_data = False, load_primary_data = False, load_pseudo_gt = False, data_type = 'synthetic'): 41 | super().__init__(train=train) 42 | 43 | self.settings_path = settings_path 44 | self.sample_paths = sample_paths 45 | self.data_aug = data_aug 46 | self.train = train 47 | self.track_length=track_length 48 | self.load_flow_data = load_flow_data 49 | self.load_primary_data = load_primary_data 50 | self.load_pseudo_gt = load_pseudo_gt 51 | self.data_type = data_type 52 | assert(track_length<=4) 53 | 54 | with open(str(settings_path), 'rb') as f: 55 | settings = pickle.load(f) 56 | self.imsizes = [(settings['imsize'][0] // (2 ** s), settings['imsize'][1] // (2 ** s)) for s in range(4)] 57 | self.patterns = [] 58 | for imsize in self.imsizes: 59 | pat = cv2.resize(settings['pattern'], (imsize[1], imsize[0]), interpolation=cv2.INTER_LINEAR) 60 | self.patterns.append(pat) 61 | self.baseline = settings['baseline'] 62 | self.K = settings['K'] 63 | self.focal_lengths = [self.K[0,0]/(2**s) for s in range(4)] 64 | 65 | self.scale = len(self.imsizes) 66 | 67 | self.max_shift = 0 68 | self.max_blur = 0.5 69 | self.max_noise = 3.0 70 | self.max_sp_noise = 0.0005 71 | 72 | def __len__(self): 73 | return len(self.sample_paths) 74 | 75 | def __getitem__(self, idx): 76 | if not self.train: 77 | rng = self.get_rng(idx) 78 | else: 79 | rng = np.random.RandomState() 80 | sample_path = self.sample_paths[idx] 81 | 82 | if self.train: 83 | track_ind = np.random.permutation(4)[0:self.track_length] 84 | else: 85 | track_ind = np.arange(0, self.track_length) 86 | 87 | ret = {} 88 | ret['id'] = idx 89 | 90 | with h5py.File(os.path.join(sample_path,f'frames.hdf5'), "r") as sample_frames: 91 | # load imgs, at all scales 92 | for sidx in range(len(self.imsizes)): 93 | im = sample_frames['im'] 94 | amb = sample_frames['ambient'] 95 | grad = sample_frames['grad'] 96 | if sidx == 0: 97 | ret[f'im{sidx}'] = np.stack([im[tidx, ...] 
for tidx in track_ind], axis=0) 98 | ret[f'ambient{sidx}'] = np.stack([amb[tidx, ...] for tidx in track_ind], axis=0) 99 | ret[f'grad{sidx}'] = np.stack([grad[tidx, ...] for tidx in track_ind], axis=0) 100 | else: 101 | ret[f'im{sidx}'] = np.stack([cv2.resize(im[tidx, 0, ...], self.imsizes[sidx][::-1])[None] for tidx in track_ind], axis=0) 102 | ret[f'ambient{sidx}'] = np.stack([cv2.resize(amb[tidx, 0, ...], self.imsizes[sidx][::-1])[None] for tidx in track_ind], axis=0) 103 | ret[f'grad{sidx}'] = np.stack([cv2.resize(grad[tidx, 0, ...], self.imsizes[sidx][::-1])[None] for tidx in track_ind], axis=0) 104 | 105 | # load disp and grad only at full resolution 106 | ret[f'disp0'] = np.stack([sample_frames['disp'][tidx, ...] for tidx in track_ind], axis=0) 107 | ret['R'] = np.stack([sample_frames['R'][tidx, ...] for tidx in track_ind], axis=0) 108 | ret['t'] = np.stack([sample_frames['t'][tidx, ...] for tidx in track_ind], axis=0) 109 | if self.data_type == 'real': 110 | ret[f'sgm_disp'] = np.stack([sample_frames['sgm_disp'][tidx, ...] for tidx in track_ind], axis=0) 111 | 112 | if self.load_flow_data: 113 | with h5py.File(os.path.join(sample_path, f'flow.hdf5'), "r") as sample_flow: 114 | for i0, tidx0 in enumerate(track_ind): 115 | for i1, tidx1 in enumerate(track_ind): 116 | if tidx0 != tidx1: 117 | ret[f'flow_{i0}{i1}'] = sample_flow[f'flow_{tidx0}{tidx1}'][:] 118 | 119 | if self.load_primary_data: 120 | with h5py.File(os.path.join(sample_path, f'single_frame_disp.hdf5'), "r") as sample_primary_disp: 121 | ret[f'primary_disp'] = np.stack([sample_primary_disp['disp'][tidx, ...] for tidx in track_ind], axis=0) 122 | 123 | if self.load_pseudo_gt: 124 | with h5py.File(os.path.join(sample_path, f'multi_frame_disp.hdf5'), "r") as sample_disp: 125 | ret[f'pseudo_gt'] = np.stack([sample_disp['disp'][tidx, ...] 
for tidx in track_ind], axis=0) 126 | 127 | #### apply data augmentation at different scales seperately, only work for max_shift=0 128 | if self.data_aug: 129 | for sidx in range(len(self.imsizes)): 130 | if sidx==0: 131 | img = ret[f'im{sidx}'] 132 | amb = ret[f'ambient{sidx}'] 133 | disp = ret[f'disp{sidx}'] 134 | if self.load_primary_data: 135 | primary_disp = ret[f'primary_disp'] 136 | else: 137 | primary_disp = None 138 | if self.data_type == 'real': 139 | sgm_disp = ret[f'sgm_disp'] 140 | else: 141 | sgm_disp = None 142 | grad = ret[f'grad{sidx}'] 143 | img_aug = np.zeros_like(img) 144 | amb_aug = np.zeros_like(img) 145 | disp_aug = np.zeros_like(img) 146 | primary_disp_aug = np.zeros_like(img) 147 | sgm_disp_aug = np.zeros_like(img) 148 | grad_aug = np.zeros_like(img) 149 | for i in range(img.shape[0]): 150 | if self.load_primary_data: 151 | primary_disp_i = primary_disp[i,0] 152 | else: 153 | primary_disp_i = None 154 | if self.data_type == 'real': 155 | sgm_disp_i = sgm_disp[i,0] 156 | else: 157 | sgm_disp_i = None 158 | img_aug_, amb_aug_, disp_aug_, primary_disp_aug_, sgm_disp_aug_, grad_aug_ = augment_image(img[i,0],rng, 159 | amb=amb[i,0],disp=disp[i,0],primary_disp=primary_disp_i, sgm_disp= sgm_disp_i,grad=grad[i,0], 160 | max_shift=self.max_shift, max_blur=self.max_blur, 161 | max_noise=self.max_noise, max_sp_noise=self.max_sp_noise) 162 | img_aug[i] = img_aug_[None].astype(np.float32) 163 | amb_aug[i] = amb_aug_[None].astype(np.float32) 164 | disp_aug[i] = disp_aug_[None].astype(np.float32) 165 | if self.load_primary_data: 166 | primary_disp_aug[i] = primary_disp_aug_[None].astype(np.float32) 167 | if self.data_type == 'real': 168 | sgm_disp_aug[i] = sgm_disp_aug_[None].astype(np.float32) 169 | grad_aug[i] = grad_aug_[None].astype(np.float32) 170 | ret[f'im{sidx}'] = img_aug 171 | ret[f'ambient{sidx}'] = amb_aug 172 | ret[f'disp{sidx}'] = disp_aug 173 | if self.load_primary_data: 174 | ret[f'primary_disp'] = primary_disp_aug 175 | if self.data_type == 'real': 176 | ret[f'sgm_disp'] = sgm_disp_aug 177 | ret[f'grad{sidx}'] = grad_aug 178 | else: 179 | img = ret[f'im{sidx}'] 180 | img_aug = np.zeros_like(img) 181 | for i in range(img.shape[0]): 182 | img_aug_, _, _, _, _, _ = augment_image(img[i,0],rng, 183 | max_shift=self.max_shift, max_blur=self.max_blur, 184 | max_noise=self.max_noise, max_sp_noise=self.max_sp_noise) 185 | img_aug[i] = img_aug_[None].astype(np.float32) 186 | ret[f'im{sidx}'] = img_aug 187 | 188 | return ret 189 | 190 | def getK(self, sidx=0): 191 | K = self.K.copy() / (2**sidx) 192 | K[2,2] = 1 193 | return K 194 | 195 | 196 | 197 | if __name__ == '__main__': 198 | pass 199 | 200 | -------------------------------------------------------------------------------- /data/default_pattern.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idiap/DepthInSpace/fe759807f82df4c48c16b97f061718175ea0e6e9/data/default_pattern.png -------------------------------------------------------------------------------- /data/kinect_pattern.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idiap/DepthInSpace/fe759807f82df4c48c16b97f061718175ea0e6e9/data/kinect_pattern.png -------------------------------------------------------------------------------- /data/presave_disp.py: -------------------------------------------------------------------------------- 1 | # DepthInSpace is a PyTorch-based program which estimates 3D depth maps 2 | # from 
active structured-light sensor's multiple video frames. 3 | # 4 | # MIT License 5 | # 6 | # Copyright, 2021 ams International AG 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to the following conditions: 14 | # 15 | # The above copyright notice and this permission notice shall be included in all 16 | # copies or substantial portions of the Software. 17 | # 18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | # SOFTWARE. 25 | 26 | import numpy as np 27 | from pathlib import Path 28 | import os 29 | import sys 30 | import argparse 31 | import json 32 | import pickle 33 | import h5py 34 | import torch 35 | from tqdm import tqdm 36 | 37 | sys.path.append('../') 38 | from model import networks 39 | from model import multi_frame_networks 40 | 41 | if __name__ == '__main__': 42 | 43 | parser = argparse.ArgumentParser() 44 | parser.add_argument('architecture', 45 | help='Select the architecture to produce disparity (single_frame/multi_frame)', 46 | choices=['single_frame', 'multi_frame'], type=str) 47 | parser.add_argument('--epoch', 48 | help='Epoch whose results will be pre-saved', 49 | default=-1, type=int) 50 | args = parser.parse_args() 51 | 52 | # output directory 53 | config_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'config.json')) 54 | with open(config_path) as fp: 55 | config = json.load(fp) 56 | data_root = Path(config['DATA_DIR']) 57 | output_dir = Path(config['OUTPUT_DIR']) 58 | 59 | model_path = output_dir / args.architecture / f'net_{args.epoch:04d}.params' 60 | settings_path = data_root / 'settings.pkl' 61 | with open(str(settings_path), 'rb') as f: 62 | settings = pickle.load(f) 63 | imsizes = [(settings['imsize'][0] // (2 ** s), settings['imsize'][1] // (2 ** s)) for s in range(4)] 64 | K = settings['K'] 65 | Ki = np.linalg.inv(K) 66 | baseline = settings['baseline'] 67 | pat = settings['pattern'] 68 | 69 | d2d = networks.DispToDepth(focal_length= K[0, 0],baseline= baseline).cuda() 70 | lcn_in = networks.LCN(5, 0.05).cuda() 71 | 72 | pat = pat.mean(axis=2) 73 | pat = torch.from_numpy(pat[None][None].astype(np.float32)).to('cuda') 74 | pat_lcn, _ = lcn_in(pat) 75 | pat_cat = torch.cat((pat_lcn, pat), dim=1) 76 | 77 | if args.architecture == 'single_frame': 78 | net = networks.DispDecoder(channels_in=2, max_disp=128, imsizes=imsizes).cuda().eval() 79 | elif args.architecture == 'multi_frame': 80 | net = multi_frame_networks.FuseNet(imsize=imsizes[0], K=K, baseline=baseline, max_disp=128).cuda().eval() 81 | 82 | net.load_state_dict(torch.load(model_path)) 83 | 84 | with torch.no_grad(): 85 | sample_list = [os.path.join(data_root, o) for o in os.listdir(data_root) if os.path.isdir(os.path.join(data_root,o))] 86 | for 
sample_path in tqdm(sample_list, ascii = True): 87 | with h5py.File(os.path.join(sample_path, f'frames.hdf5'), "r") as frames_sample: 88 | if args.architecture == 'single_frame': 89 | im = torch.tensor(frames_sample['im']).cuda() 90 | im_lcn, im_std = lcn_in(im) 91 | im = torch.cat([im_lcn, im], dim=1) 92 | disp = net(im)[0].cpu().numpy() 93 | elif args.architecture == 'multi_frame': 94 | im = torch.tensor(frames_sample['im']).cuda() 95 | im_lcn, im_std = lcn_in(im) 96 | im = torch.cat([im_lcn, im], dim=1) 97 | 98 | amb = torch.tensor(frames_sample['ambient']).cuda() 99 | 100 | R = torch.tensor(frames_sample['R']).cuda() 101 | t = torch.tensor(frames_sample['t']).cuda() 102 | 103 | flow = {} 104 | with h5py.File(os.path.join(sample_path, f'flow.hdf5'), "r") as sample_flow: 105 | for i0 in range(4): 106 | for i1 in range(4): 107 | if i0 != i1: 108 | flow[f'flow_{i0}{i1}'] = torch.tensor(sample_flow[f'flow_{i0}{i1}'][:]).cuda() 109 | 110 | with h5py.File(os.path.join(sample_path, f'single_frame_disp.hdf5'), "r") as sample_primary_disp: 111 | primary_disp = torch.tensor(sample_primary_disp['disp'][:]).cuda() 112 | 113 | disp = net(im.unsqueeze(1), amb.unsqueeze(1), primary_disp.unsqueeze(1), d2d(primary_disp.unsqueeze(1)), R.unsqueeze(1), t.unsqueeze(1), flow).detach().cpu().numpy() 114 | disp = disp[:, 0] 115 | 116 | with h5py.File(os.path.join(sample_path, f'{args.architecture}_disp.hdf5'), "w") as f: 117 | f.create_dataset('disp', data=disp) -------------------------------------------------------------------------------- /data/presave_optical_flow_data.py: -------------------------------------------------------------------------------- 1 | # DepthInSpace is a PyTorch-based program which estimates 3D depth maps 2 | # from active structured-light sensor's multiple video frames. 3 | # 4 | # MIT License 5 | # 6 | # Copyright, 2021 ams International AG 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to the following conditions: 14 | # 15 | # The above copyright notice and this permission notice shall be included in all 16 | # copies or substantial portions of the Software. 17 | # 18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | # SOFTWARE. 
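# data/presave_optical_flow_data.py (below) precomputes optical flow for the
# saved tracks by shelling out to the pytorch-liteflownet checkout pointed to
# by LITEFLOWNET_DIR in config.json: it puts the LiteFlowNet code on
# PYTHONPATH and runs its run.py with '--model default --data_path DATA_DIR'.
# The pairwise flow between the four frames of each track presumably ends up
# in a flow.hdf5 file next to each sample's frames.hdf5, since dataset.py and
# presave_disp.py later read datasets named flow_{i0}{i1} from that file.
# The command it effectively runs (illustrative, paths are placeholders):
#
#   PYTHONPATH=/path/to/pytorch-liteflownet:/path/to/pytorch-liteflownet/correlation \
#       python /path/to/pytorch-liteflownet/run.py --model default --data_path /path/to/data/dir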
25 | 26 | from pathlib import Path 27 | import os 28 | import json 29 | import pickle 30 | 31 | if __name__ == '__main__': 32 | 33 | # output directory 34 | config_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'config.json')) 35 | with open(config_path) as fp: 36 | config = json.load(fp) 37 | data_root = Path(config['DATA_DIR']) 38 | liteflownet_path = Path(config['LITEFLOWNET_DIR']) 39 | 40 | script_path = str(liteflownet_path / 'run.py') 41 | data_path = str(data_root) 42 | os.environ['PYTHONPATH'] = str(liteflownet_path) + os.pathsep + str(liteflownet_path / 'correlation') 43 | os.system('python ' + script_path + ' --model default' + ' --data_path ' + data_path) 44 | -------------------------------------------------------------------------------- /data/real_pattern.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idiap/DepthInSpace/fe759807f82df4c48c16b97f061718175ea0e6e9/data/real_pattern.png -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | # DepthInSpace is a PyTorch-based program which estimates 3D depth maps 2 | # from active structured-light sensor's multiple video frames. 3 | # 4 | # Copyright (c) 2021 ams International AG 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 7 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation 8 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 9 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 10 | # 11 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES 14 | # OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE 15 | # LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR 16 | # IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /model/ext_functions.py: -------------------------------------------------------------------------------- 1 | # DepthInSpace is a PyTorch-based program which estimates 3D depth maps 2 | # from active structured-light sensor's multiple video frames. 3 | # 4 | # MIT License 5 | # 6 | # Copyright (c) 2019 autonomousvision 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to the following conditions: 14 | # 15 | # The above copyright notice and this permission notice shall be included in all 16 | # copies or substantial portions of the Software. 
17 | # 18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | # SOFTWARE. 25 | 26 | import torch 27 | import sys 28 | import json 29 | from pathlib import Path 30 | import os 31 | 32 | config_path = os.path.abspath(os.path.join(os.path.dirname( __file__ ), '..', 'config.json')) 33 | with open(config_path) as fp: 34 | config = json.load(fp) 35 | CTD_path = Path(config['CTD_DIR']) 36 | sys.path.append(str(CTD_path / 'torchext')) 37 | 38 | import ext_cpu 39 | import ext_cuda 40 | 41 | class NNFunction(torch.autograd.Function): 42 | @staticmethod 43 | def forward(ctx, in0, in1): 44 | args = (in0, in1) 45 | if in0.is_cuda: 46 | out = ext_cuda.nn_cuda(*args) 47 | else: 48 | out = ext_cpu.nn_cpu(*args) 49 | return out 50 | 51 | @staticmethod 52 | def backward(ctx, grad_out): 53 | return None, None 54 | 55 | def nn(in0, in1): 56 | return NNFunction.apply(in0, in1) 57 | 58 | 59 | class CrossCheckFunction(torch.autograd.Function): 60 | @staticmethod 61 | def forward(ctx, in0, in1): 62 | args = (in0, in1) 63 | if in0.is_cuda: 64 | out = ext_cuda.crosscheck_cuda(*args) 65 | else: 66 | out = ext_cpu.crosscheck_cpu(*args) 67 | return out 68 | 69 | @staticmethod 70 | def backward(ctx, grad_out): 71 | return None, None 72 | 73 | def crosscheck(in0, in1): 74 | return CrossCheckFunction.apply(in0, in1) 75 | 76 | class ProjNNFunction(torch.autograd.Function): 77 | @staticmethod 78 | def forward(ctx, xyz0, xyz1, K, patch_size): 79 | args = (xyz0, xyz1, K, patch_size) 80 | if xyz0.is_cuda: 81 | out = ext_cuda.proj_nn_cuda(*args) 82 | else: 83 | out = ext_cpu.proj_nn_cpu(*args) 84 | return out 85 | 86 | @staticmethod 87 | def backward(ctx, grad_out): 88 | return None, None, None, None 89 | 90 | def proj_nn(xyz0, xyz1, K, patch_size): 91 | return ProjNNFunction.apply(xyz0, xyz1, K, patch_size) 92 | 93 | 94 | 95 | class XCorrVolFunction(torch.autograd.Function): 96 | @staticmethod 97 | def forward(ctx, in0, in1, n_disps, block_size): 98 | args = (in0, in1, n_disps, block_size) 99 | if in0.is_cuda: 100 | out = ext_cuda.xcorrvol_cuda(*args) 101 | else: 102 | out = ext_cpu.xcorrvol_cpu(*args) 103 | return out 104 | 105 | @staticmethod 106 | def backward(ctx, grad_out): 107 | return None, None, None, None 108 | 109 | def xcorrvol(in0, in1, n_disps, block_size): 110 | return XCorrVolFunction.apply(in0, in1, n_disps, block_size) 111 | 112 | 113 | 114 | 115 | class PhotometricLossFunction(torch.autograd.Function): 116 | @staticmethod 117 | def forward(ctx, es, ta, block_size, type, eps): 118 | args = (es, ta, block_size, type, eps) 119 | ctx.save_for_backward(es, ta) 120 | ctx.block_size = block_size 121 | ctx.type = type 122 | ctx.eps = eps 123 | if es.is_cuda: 124 | out = ext_cuda.photometric_loss_forward(*args) 125 | else: 126 | out = ext_cpu.photometric_loss_forward(*args) 127 | return out 128 | 129 | @staticmethod 130 | def backward(ctx, grad_out): 131 | es, ta = ctx.saved_tensors 132 | block_size = ctx.block_size 133 | type = ctx.type 134 | eps = ctx.eps 135 | args = (es, ta, grad_out.contiguous(), block_size, type, eps) 136 | if grad_out.is_cuda: 137 | grad_es = 
ext_cuda.photometric_loss_backward(*args) 138 | else: 139 | grad_es = ext_cpu.photometric_loss_backward(*args) 140 | return grad_es, None, None, None, None 141 | 142 | def photometric_loss(es, ta, block_size, type='mse', eps=0.1): 143 | type = type.lower() 144 | if type == 'mse': 145 | type = 0 146 | elif type == 'sad': 147 | type = 1 148 | elif type == 'census_mse': 149 | type = 2 150 | elif type == 'census_sad': 151 | type = 3 152 | else: 153 | raise Exception('invalid loss type') 154 | return PhotometricLossFunction.apply(es, ta, block_size, type, eps) 155 | 156 | def photometric_loss_pytorch(es, ta, block_size, type='mse', eps=0.1): 157 | type = type.lower() 158 | p = block_size // 2 159 | es_pad = torch.nn.functional.pad(es, (p,p,p,p), mode='replicate') 160 | ta_pad = torch.nn.functional.pad(ta, (p,p,p,p), mode='replicate') 161 | es_uf = torch.nn.functional.unfold(es_pad, kernel_size=block_size) 162 | ta_uf = torch.nn.functional.unfold(ta_pad, kernel_size=block_size) 163 | es_uf = es_uf.view(es.shape[0], es.shape[1], -1, es.shape[2], es.shape[3]) 164 | ta_uf = ta_uf.view(ta.shape[0], ta.shape[1], -1, ta.shape[2], ta.shape[3]) 165 | if type == 'mse': 166 | ref = (es_uf - ta_uf)**2 167 | elif type == 'sad': 168 | ref = torch.abs(es_uf - ta_uf) 169 | elif type == 'census_mse' or type == 'census_sad': 170 | des = es_uf - es.unsqueeze(2) 171 | dta = ta_uf - ta.unsqueeze(2) 172 | h_des = 0.5 * (1 + des / torch.sqrt(des * des + eps)) 173 | h_dta = 0.5 * (1 + dta / torch.sqrt(dta * dta + eps)) 174 | diff = h_des - h_dta 175 | if type == 'census_mse': 176 | ref = diff * diff 177 | elif type == 'census_sad': 178 | ref = torch.abs(diff) 179 | else: 180 | raise Exception('invalid loss type') 181 | ref = ref.view(es.shape[0], -1, es.shape[2], es.shape[3]) 182 | ref = torch.sum(ref, dim=1, keepdim=True) / block_size**2 183 | return ref -------------------------------------------------------------------------------- /model/multi_frame_networks.py: -------------------------------------------------------------------------------- 1 | # DepthInSpace is a PyTorch-based program which estimates 3D depth maps 2 | # from active structured-light sensor's multiple video frames. 3 | # 4 | # MIT License 5 | # 6 | # Copyright, 2021 ams International AG 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to the following conditions: 14 | # 15 | # The above copyright notice and this permission notice shall be included in all 16 | # copies or substantial portions of the Software. 17 | # 18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | # SOFTWARE. 
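# Illustrative note (not original code): FuseNet.tforward below expects every per-frame tensor
# to carry a leading (track_length, batch) pair of dimensions, matching the call in
# data/presave_disp.py. Under that assumption, a minimal invocation looks like:
#
#   net = FuseNet(imsize=(H, W), K=K, baseline=baseline, track_length=4).cuda()
#   disp = net(ir,                 # (4, bs, 2, H, W)  LCN-normalized IR stacked with raw IR
#              amb,                # (4, bs, 1, H, W)  ambient image
#              primary_disp,       # (4, bs, 1, H, W)  single-frame (primary) disparity
#              d2d(primary_disp),  # (4, bs, 1, H, W)  the same disparity converted to depth
#              R, t,               # (4, bs, 3, 3) and (4, bs, 3)  per-frame camera poses
#              flow)               # dict of (bs, 2, H, W) flows, keys 'flow_{i}{j}', i != j
#   # disp: (4, bs, 1, H, W) refined disparity for each of the 4 frames of the track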
25 | 26 | import cv2 27 | 28 | import torch 29 | import numpy as np 30 | 31 | from torch.utils.checkpoint import checkpoint 32 | 33 | from .networks import TimedModule, OutputLayerFactory 34 | 35 | 36 | def merge_tl_bs(x): 37 | return x.contiguous().view(-1, *x.shape[2:]) 38 | 39 | def split_tl_bs(x, tl, bs): 40 | return x.contiguous().view(tl, bs, *x.shape[1:]) 41 | 42 | def resize_like(x, target): 43 | x_shape = x.shape[:-3] 44 | height = target.shape[-2] 45 | width = target.shape[-1] 46 | 47 | out = torch.nn.functional.interpolate(x.contiguous().view(-1, *x.shape[-3:]), size=(height, width), mode='bilinear', 48 | align_corners=True) 49 | out = out.view(*x_shape, *out.shape[1:]) 50 | 51 | return out 52 | 53 | 54 | def resize_flow_like(flow, target): 55 | height = target.shape[-2] 56 | width = target.shape[-1] 57 | 58 | out = {} 59 | for key, val in flow.items(): 60 | flow_height = val.shape[-2] 61 | flow_width = val.shape[-1] 62 | 63 | resized_flow = torch.nn.functional.interpolate(val, size=(height, width), mode='bilinear', align_corners=True) 64 | resized_flow[:, 0, :, :] *= float(width) / float(flow_width) 65 | resized_flow[:, 1, :, :] *= float(height) / float(flow_height) 66 | out[key] = resized_flow 67 | 68 | return out 69 | 70 | def resize_flow_masks_like(flow_masks, target): 71 | height = target.shape[-2] 72 | width = target.shape[-1] 73 | 74 | with torch.no_grad(): 75 | out = {} 76 | for key, val in flow_masks.items(): 77 | resized_mask = torch.nn.functional.interpolate(val, size=(height, width), mode='bilinear', align_corners=True) 78 | 79 | out[key] = (resized_mask > 0.5).float() 80 | 81 | return out 82 | 83 | def warp(x, flow): 84 | width = x.shape[-1] 85 | height = x.shape[-2] 86 | 87 | u, v = np.meshgrid(range(width), range(height)) 88 | u = torch.from_numpy(u.astype('float32')).to(x.device) 89 | v = torch.from_numpy(v.astype('float32')).to(x.device) 90 | 91 | uv_prj = flow.clone().permute(0, 2, 3, 1) 92 | uv_prj[..., 0] += u 93 | uv_prj[..., 1] += v 94 | 95 | uv_prj[..., 0] = 2 * (uv_prj[..., 0] / (width - 1) - 0.5) 96 | uv_prj[..., 1] = 2 * (uv_prj[..., 1] / (height - 1) - 0.5) 97 | x_prj = torch.nn.functional.grid_sample(x, uv_prj, padding_mode='zeros', align_corners=True) 98 | 99 | return x_prj 100 | 101 | class FuseNet(TimedModule): 102 | ''' 103 | Fuse Net 104 | ''' 105 | def __init__(self, imsize, K, baseline, track_length = 4, block_num = 4, channels = 32, max_disp= 128, movement_mask_en = 1): 106 | super(FuseNet, self).__init__(mod_name='FuseNet') 107 | self.movement_mask_en = movement_mask_en 108 | 109 | self.im_height = imsize[0] 110 | self.im_width = imsize[1] 111 | self.core_height = self.im_height // 2 112 | self.core_width = self.im_width // 2 113 | self.K = K 114 | self.Ki = np.linalg.inv(K) 115 | self.baseline = baseline 116 | self.track_length = track_length 117 | self.block_num = block_num 118 | self.channels = channels 119 | self.max_disp = max_disp 120 | 121 | u, v = np.meshgrid(range(self.im_width), range(self.im_height)) 122 | u = cv2.resize(u, (self.core_width, self.core_height), interpolation = cv2.INTER_NEAREST) 123 | v = cv2.resize(v, (self.core_width, self.core_height), interpolation = cv2.INTER_NEAREST) 124 | uv = np.stack((u,v,np.ones_like(u)), axis=2).reshape(-1,3) 125 | 126 | ray = uv @ self.Ki.T 127 | ray = ray.reshape(1, 1,-1,3).astype(np.float32) 128 | self.ray = torch.from_numpy(ray).cuda() 129 | 130 | self.conv1 = self.conv(4, self.channels // 2, kernel_size=4, stride=2) 131 | self.conv2 = self.conv(self.channels // 2, self.channels // 2, 
kernel_size=3, stride=1) 132 | 133 | # self.conv3 = self.conv(self.channels // 2, self.channels, kernel_size=4, stride=2) 134 | self.conv3 = self.conv(self.channels // 2, self.channels, kernel_size=3, stride=1) 135 | self.conv4 = self.conv(self.channels, self.channels, kernel_size=3, stride=1) 136 | 137 | self.res1 = ResNetBlock(self.channels) 138 | self.res2 = ResNetBlock(self.channels) 139 | self.res3 = ResNetBlock(self.channels) 140 | 141 | self.blocks = torch.nn.ModuleList([Block2D3D(channels = self.channels, tl= self.track_length) for i in range(self.block_num)]) 142 | 143 | self.upconv1 = self.upconv(self.channels, self.channels) 144 | self.upconv2 = self.upconv(self.channels, self.channels) 145 | 146 | self.amb_conv = self.conv(1, 16, kernel_size=3, stride=1) 147 | self.amb_res1 = ResNetBlock(16) 148 | self.amb_res2 = ResNetBlock(16) 149 | 150 | self.ref_conv = self.conv(16 + self.channels, 32, kernel_size=3, stride=1) 151 | self.ref_res1 = ResNetBlock(32) 152 | self.ref_res2 = ResNetBlock(32) 153 | self.ref_res3 = ResNetBlock(32) 154 | 155 | self.final_conv = self.conv(32, 16, kernel_size=3, stride=1) 156 | 157 | self.predict_disp = OutputLayerFactory( type='disp', params={ 'alpha': self.max_disp, 'beta': 0, 'gamma': 1, 'offset': 3})(16) 158 | 159 | def conv(self, in_planes, out_planes, kernel_size=3, stride=1): 160 | return torch.nn.Sequential( 161 | torch.nn.ZeroPad2d((kernel_size - 1) // 2), 162 | torch.nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=0), 163 | torch.nn.SELU(inplace=True) 164 | ) 165 | 166 | def upconv(self, in_planes, out_planes): 167 | return torch.nn.Sequential( 168 | torch.nn.ConvTranspose2d(in_planes, out_planes, kernel_size=4, stride=2, padding=1), 169 | torch.nn.SELU(inplace=True) 170 | ) 171 | 172 | def unproject(self, d, R, t): 173 | tl = d.shape[0] 174 | bs = d.shape[1] 175 | xyz = d.view(tl, bs, -1, 1) * self.ray 176 | xyz = xyz - t.view(tl, bs, 1, 3) 177 | xyz = torch.matmul(xyz, R) 178 | 179 | return xyz 180 | 181 | def change_view_angle(self, xyz, R, t): 182 | xyz_changed = torch.matmul(xyz, R.transpose(1,2)) 183 | xyz_changed = xyz_changed + t.unsqueeze(1).unsqueeze(0) 184 | 185 | return xyz_changed 186 | 187 | def gather_warped_xyz(self, tidx, xyz, depth, flow, amb): 188 | tl = xyz.shape[0] 189 | bs = xyz.shape[1] 190 | 191 | frame_inds = [i for i in range(tl) if i != tidx] 192 | 193 | warped_xyz = [] 194 | warped_xyz.append(xyz[tidx].transpose(1, 2).view(bs, 3, self.core_height, self.core_width)) 195 | 196 | warped_mask = [] 197 | warped_mask.append(torch.ones(bs, 1, self.core_height, self.core_width).to(xyz.device)) 198 | 199 | for j in frame_inds: 200 | warped_xyz.append(warp(xyz[j].transpose(1, 2).view(bs, 3, self.core_height, self.core_width), flow[f'flow_{tidx}{j}'])) 201 | 202 | with torch.no_grad(): 203 | flow0 = flow[f'flow_{tidx}{j}'].detach() 204 | flow10 = warp(flow[f'flow_{j}{tidx}'].detach(), flow0) 205 | fb_mask = ((flow0.detach() + flow10) ** 2).sum(dim=1) < 0.5 + 0.01 * ( 206 | (flow0.detach() ** 2).sum(dim=1) + (flow10 ** 2).sum(dim=1)) 207 | fb_mask = fb_mask.type(torch.float32).unsqueeze(1) 208 | 209 | warped_mask.append(fb_mask) 210 | 211 | warped_xyz = torch.stack(warped_xyz, dim=0) 212 | warped_mask = torch.stack(warped_mask, dim=0) 213 | 214 | return warped_xyz, warped_mask 215 | 216 | def pre_process(self, input_data, d, checkpoint_var): 217 | out_conv1 = self.conv1(torch.cat([input_data, d], dim= 1)) 218 | out_conv2 = self.conv2(out_conv1) 219 | 220 | out_conv3 = self.conv3(out_conv2) 221 
| out_conv4 = self.conv4(out_conv3) 222 | 223 | out_res1 = self.res1(out_conv4) 224 | out_res2 = self.res2(out_res1) 225 | feat = self.res3(out_res2) 226 | 227 | return feat 228 | 229 | def process_amb(self, amb, feat): 230 | out_amb_conv = self.amb_conv(merge_tl_bs(amb)) 231 | out_amb_res1 = self.amb_res1(out_amb_conv) 232 | out_amb_res2 = self.amb_res2(out_amb_res1) 233 | 234 | out_process_upc = self.process_upc(feat, out_amb_res2) 235 | 236 | return out_process_upc 237 | 238 | def process_upc(self, feat, out_process_amb): 239 | # out_upconv1 = self.upconv1(feat) 240 | # out_upconv2 = self.upconv2(out_upconv1) 241 | # out_upconv = out_upconv2 242 | 243 | # out_upconv1 = self.upconv1(feat) 244 | # out_upconv = out_upconv1 245 | 246 | out_upconv = torch.nn.functional.interpolate(feat, size=(self.im_height, self.im_width), mode='bilinear', 247 | align_corners=True) 248 | 249 | out_ref_conv = self.ref_conv(torch.cat([out_upconv, out_process_amb], dim=1)) 250 | 251 | return out_ref_conv 252 | 253 | def post_process(self, feat, amb): 254 | out_process_amb = checkpoint(self.process_amb, amb, feat, preserve_rng_state= False) 255 | # out_process_amb = self.process_amb(amb, feat) 256 | 257 | out_ref_res1 = checkpoint(self.ref_res1, out_process_amb, preserve_rng_state= False) 258 | # out_ref_res1 = self.ref_res1(out_process_amb) 259 | out_ref_res2 = checkpoint(self.ref_res2, out_ref_res1, preserve_rng_state= False) 260 | # out_ref_res2 = self.ref_res2(out_ref_res1) 261 | out_ref_res3 = checkpoint(self.ref_res3, out_ref_res2, preserve_rng_state= False) 262 | # out_ref_res3 = self.ref_res3(out_ref_res2) 263 | 264 | out_final_conv = self.final_conv(out_ref_res3) 265 | disp = self.predict_disp(out_final_conv) 266 | 267 | return disp 268 | 269 | def tforward(self, ir, amb, d, depth, R, t, flow): 270 | tl = ir.shape[0] 271 | bs = ir.shape[1] 272 | 273 | input_data = merge_tl_bs(torch.cat((ir, amb), 2)) 274 | checkpoint_var = torch.tensor([0.0]).to(input_data.device).requires_grad_(True) 275 | # feat = checkpoint(self.pre_process, input_data, merge_tl_bs(d), checkpoint_var, preserve_rng_state= False) 276 | feat = self.pre_process(input_data, merge_tl_bs(d), checkpoint_var) 277 | 278 | ###### Block Part 279 | core_feat = split_tl_bs(feat, tl, bs) 280 | core_depth = resize_like(depth, core_feat) 281 | core_flow = resize_flow_like(flow, core_feat) 282 | core_amb = resize_like(amb, core_feat) 283 | xyz = self.unproject(core_depth, R, t) 284 | 285 | warped_xyz = [] 286 | warped_mask = [] 287 | for tidx in range(tl): 288 | xyz_changed = self.change_view_angle(xyz, R[tidx], t[tidx]) 289 | w_xyz, w_mask = self.gather_warped_xyz(tidx, xyz_changed, core_depth, core_flow, core_amb) 290 | 291 | warped_xyz.append(w_xyz) 292 | warped_mask.append(w_mask) 293 | warped_xyz = torch.stack(warped_xyz, dim=0) 294 | warped_mask = torch.stack(warped_mask, dim=0) 295 | 296 | for block in self.blocks: 297 | core_feat = block(core_feat, warped_xyz, warped_mask, core_flow) 298 | 299 | feat = merge_tl_bs(core_feat) 300 | #### End of Block Part 301 | 302 | disp = self.post_process(feat, amb) 303 | out = disp.view(tl, bs, *disp.shape[1:]) 304 | 305 | return out 306 | 307 | class Block2D3D(TimedModule): 308 | def __init__(self, channels, tl): 309 | super(Block2D3D, self).__init__(mod_name='Block2D3D') 310 | 311 | self.channels = channels 312 | self.tl = tl 313 | 314 | self.conv_mf = self.conv(self.channels * self.tl, self.channels, kernel_size=1, stride=1, activation='none') 315 | 316 | self.conv1_1 = self.conv(self.channels, 
self.channels, kernel_size=3, stride=1, activation='relu') 317 | self.conv1_2 = self.conv(self.channels, self.channels, kernel_size=3, stride=1, activation='relu') 318 | 319 | self.conv2_1 = self.conv(self.channels, self.channels, kernel_size=4, stride=2, activation='relu') 320 | self.conv2_2 = self.conv(self.channels, self.channels, kernel_size=3, stride=1, activation='relu') 321 | 322 | self.conv_fuse = self.conv(self.channels * 3, self.channels, kernel_size=3, stride=1, activation='none') 323 | 324 | # self.conv_res = self.conv(self.channels, self.channels, kernel_size=3, stride=1, activation='relu') 325 | self.activation_res = torch.nn.SELU(inplace=True) 326 | 327 | self.conv3d_1 = Conv3D(channels_in= self.channels, channels_out= self.channels, tl= self.tl, stride= 2) 328 | self.conv3d_2 = Conv3D(channels_in= self.channels, channels_out= self.channels, tl= self.tl, stride= 1) 329 | 330 | def conv(self, in_planes, out_planes, kernel_size=3, stride=1, activation = 'none'): 331 | if activation == 'none': 332 | return torch.nn.Sequential( 333 | torch.nn.ZeroPad2d((kernel_size - 1) // 2), 334 | torch.nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=0), 335 | # torch.nn.BatchNorm2d(out_planes) 336 | torch.nn.GroupNorm(num_groups= 1, num_channels= out_planes) 337 | ) 338 | elif activation == 'relu': 339 | return torch.nn.Sequential( 340 | torch.nn.ZeroPad2d((kernel_size - 1) // 2), 341 | torch.nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=0), 342 | torch.nn.SELU(inplace=True), 343 | # torch.nn.BatchNorm2d(out_planes), 344 | torch.nn.GroupNorm(num_groups=1, num_channels=out_planes) 345 | ) 346 | 347 | def gather_warped_feat(self, tidx, feat, flow): 348 | tl = feat.shape[0] 349 | 350 | frame_inds = [i for i in range(tl) if i != tidx] 351 | 352 | warped_feat = [] 353 | warped_feat.append(feat[tidx]) 354 | 355 | for j in frame_inds: 356 | warped_feat.append(warp(feat[j], flow[f'flow_{tidx}{j}'])) 357 | 358 | warped_feat = torch.stack(warped_feat, dim=0) 359 | 360 | return warped_feat 361 | 362 | def tforward(self, feat, warped_xyz, warped_mask, flow): 363 | self.flow = flow 364 | 365 | out_conv3d_1, warped_feat = checkpoint(self.fwd_3d_1, feat, warped_xyz, warped_mask, preserve_rng_state= False) 366 | # out_conv3d_1, warped_feat = self.fwd_3d_1(feat, warped_xyz, warped_mask) 367 | 368 | out_conv3d_2, _ = checkpoint(self.fwd_3d_2, out_conv3d_1, warped_xyz, warped_mask, preserve_rng_state= False) 369 | # out_conv3d_2, _ = self.fwd_3d_2(out_conv3d_1, warped_xyz, warped_mask) 370 | 371 | out = checkpoint(self.fwd_2d, feat, warped_feat, warped_mask, out_conv3d_2, preserve_rng_state= False) 372 | # out = self.fwd_2d(feat, warped_feat, warped_mask, out_conv3d_2) 373 | 374 | return out 375 | 376 | def fwd_3d_1(self, feat, warped_xyz, warped_mask): 377 | tl = feat.shape[0] 378 | 379 | warped_feat = [] 380 | out_conv3d = [] 381 | for tidx in range(tl): 382 | warped_feat.append(self.gather_warped_feat(tidx, feat, self.flow)) 383 | out_conv3d.append(self.conv3d_1(warped_xyz[tidx], warped_feat[-1], warped_mask[tidx])) 384 | warped_feat = torch.stack(warped_feat, dim=0) 385 | out_conv3d = torch.stack(out_conv3d, dim=0) 386 | 387 | return out_conv3d, warped_feat 388 | 389 | def fwd_3d_2(self, feat, warped_xyz, warped_mask): 390 | tl = feat.shape[0] 391 | 392 | resized_flow = resize_flow_like(self.flow, feat) 393 | resized_warped_xyz = resize_like(warped_xyz, feat) 394 | resized_warped_mask = (resize_like(warped_mask, feat) > 0.5).float() 395 | 
396 | warped_feat = [] 397 | out_conv3d = [] 398 | for tidx in range(tl): 399 | warped_feat.append(self.gather_warped_feat(tidx, feat, resized_flow)) 400 | out_conv3d.append(self.conv3d_2(resized_warped_xyz[tidx], warped_feat[-1], resized_warped_mask[tidx])) 401 | warped_feat = torch.stack(warped_feat, dim=0) 402 | out_conv3d = torch.stack(out_conv3d, dim=0) 403 | 404 | return out_conv3d, warped_feat 405 | 406 | def fwd_2d(self, feat, warped_feat, warped_mask, out_conv3d_2): 407 | tl = feat.shape[0] 408 | bs = feat.shape[1] 409 | 410 | warped_feat_2d = (warped_feat * warped_mask / warped_mask.mean(dim=1, keepdim=True)).transpose(1, 2) 411 | warped_feat_2d = warped_feat_2d.reshape(tl * bs, -1, *warped_feat_2d.shape[4:]) 412 | 413 | out_conv_mf = self.conv_mf(warped_feat_2d) 414 | 415 | out_conv1_1 = self.conv1_1(out_conv_mf) 416 | out_conv1_2 = self.conv1_2(out_conv1_1) 417 | 418 | out_conv2_1 = self.conv2_1(out_conv_mf) 419 | out_conv2_2 = self.conv2_2(out_conv2_1) 420 | out_ups2 = torch.nn.functional.interpolate(out_conv2_2, scale_factor=2, mode='bilinear', align_corners=True) 421 | 422 | out_ups3d = torch.nn.functional.interpolate(merge_tl_bs(out_conv3d_2), scale_factor=2, mode='bilinear', align_corners=True) 423 | 424 | ### Fusion Part 425 | out_fuse = torch.cat((out_conv1_2, out_ups2, out_ups3d), dim= 1) 426 | out_conv_fuse = self.conv_fuse(out_fuse) 427 | 428 | out = self.activation_res(split_tl_bs(out_conv_fuse, tl, bs) + feat) 429 | 430 | return out 431 | 432 | class Conv3D(TimedModule): 433 | def __init__(self, channels_in, channels_out, neighbors = 9, tl = 4, ksize = 3, stride = 1, radius_sq = 0.04): 434 | super(Conv3D, self).__init__(mod_name='Conv3D') 435 | self.channels_in = channels_in 436 | self.channels_out = channels_out 437 | self.neighbors = neighbors 438 | self.tl = tl 439 | self.ksize = ksize 440 | self.stride = stride 441 | self.radius_sq = radius_sq 442 | 443 | self.dense1 = self.dense(3, self.channels_out // 2, activation= 'selu') 444 | self.dense2 = self.dense(self.channels_out // 2, self.channels_out, activation= 'selu') 445 | 446 | self.w = torch.nn.Parameter(torch.zeros([self.channels_out, self.channels_out])) 447 | torch.nn.init.xavier_uniform_(self.w, gain=0.1) 448 | 449 | self.activation = torch.nn.SELU(inplace=True) 450 | # self.bn = torch.nn.BatchNorm2d(channels_out) 451 | self.bn = torch.nn.GroupNorm(num_groups= 1, num_channels= channels_out) 452 | 453 | def dense(self, in_planes, out_planes, activation = 'selu'): 454 | if activation == 'selu': 455 | return torch.nn.Sequential( 456 | torch.nn.Linear(in_planes, out_planes), 457 | torch.nn.SELU(inplace=True), 458 | ) 459 | elif activation == 'softmax': 460 | return torch.nn.Sequential( 461 | torch.nn.Linear(in_planes, out_planes), 462 | torch.nn.Softmax(dim = -1) 463 | ) 464 | elif activation == 'none': 465 | return torch.nn.Sequential( 466 | torch.nn.Linear(in_planes, out_planes), 467 | ) 468 | 469 | def tforward(self, xyz, feat, mask, checkpoint_var = 0): 470 | padding_len = (self.ksize - 1) // 2 471 | 472 | xyz = torch.nn.functional.pad(xyz, (padding_len, padding_len, padding_len, padding_len), mode='constant', value=0) 473 | feat = torch.nn.functional.pad(feat, (padding_len, padding_len, padding_len, padding_len), mode='constant', value=0) 474 | mask = torch.nn.functional.pad(mask, (padding_len, padding_len, padding_len, padding_len), mode='constant', value=0) 475 | 476 | xyz = xyz.unfold(3, self.ksize, self.stride).unfold(4, self.ksize, self.stride) 477 | feat = feat.unfold(3, self.ksize, 
self.stride).unfold(4, self.ksize, self.stride) 478 | mask = mask.unfold(3, self.ksize, self.stride).unfold(4, self.ksize, self.stride) 479 | 480 | xyz = xyz.permute(1, 3, 4, 5, 6, 0, 2) # (bs, h, w, k, k, tl, c) 481 | feat = feat.permute(1, 3, 4, 5, 6, 0, 2) 482 | mask = mask.permute(1, 3, 4, 5, 6, 0, 2) 483 | 484 | bs_h_w = xyz.shape[0:3] 485 | xyz = xyz.reshape(-1, self.ksize * self.ksize * self.tl, xyz.shape[-1]) # (?, k*k*tl, c) 486 | feat = feat.reshape(-1, self.ksize * self.ksize * self.tl, feat.shape[-1]) 487 | mask = mask.reshape(-1, self.ksize * self.ksize * self.tl, mask.shape[-1]) 488 | 489 | xyz_plane = xyz / (xyz[..., 2:] + 1e-12) 490 | 491 | tidx = ((self.ksize ** 2) // 2) * self.tl 492 | xyz_local = xyz - xyz[:, tidx:tidx + 1, :] 493 | xyz_plane_local = xyz_plane - xyz_plane[:, tidx:tidx + 1, :] 494 | 495 | xyz_sq = (xyz_plane_local ** 2).sum(dim=-1, keepdim=True) 496 | 497 | xyz_max_copy = (mask * xyz_sq) + (1 - mask) * (xyz_sq.max() + 1) 498 | _, neighbors_ind = torch.topk(xyz_max_copy, self.neighbors, dim= 1, largest=False, sorted=False) 499 | xyz_neighbors = torch.gather(xyz_local, dim = 1, index= neighbors_ind.expand(-1, -1, xyz_local.shape[-1])) 500 | feat_neighbors = torch.gather(feat, dim = 1, index= neighbors_ind.expand(-1, -1, feat.shape[-1])) 501 | 502 | out_dense1 = self.dense1(xyz_neighbors) 503 | out_dense2 = self.dense2(out_dense1) 504 | 505 | feat_weighted = (out_dense2 * feat_neighbors).sum(dim = 1) 506 | 507 | out_conv = torch.matmul(feat_weighted, self.w).view(*bs_h_w, self.channels_out).permute(0, 3, 1, 2) 508 | out_conv = self.activation(out_conv) 509 | 510 | out = self.bn(out_conv) 511 | 512 | return out 513 | 514 | class ResNetBlock(TimedModule): 515 | def __init__(self, planes): 516 | super(ResNetBlock, self).__init__(mod_name='ResNetBlock') 517 | 518 | self.pad = torch.nn.ZeroPad2d(1) 519 | 520 | self.conv1 = torch.nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=0) 521 | # self.bn1 = torch.nn.BatchNorm2d(planes) 522 | self.bn1 = torch.nn.GroupNorm(num_groups= 1, num_channels= planes) 523 | self.relu1 = torch.nn.SELU(inplace=True) 524 | self.conv2 = torch.nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=0) 525 | # self.bn2 = torch.nn.BatchNorm2d(planes) 526 | self.bn2 = torch.nn.GroupNorm(num_groups= 1, num_channels= planes) 527 | self.relu2 = torch.nn.SELU(inplace=True) 528 | 529 | def forward(self, x): 530 | identity = x.clone() 531 | 532 | out = self.conv1(self.pad(x)) 533 | out = self.relu1(out) 534 | out = self.bn1(out) 535 | 536 | out = self.conv2(self.pad(out)) 537 | out = self.bn2(out) 538 | 539 | out += identity 540 | out = self.relu2(out) 541 | 542 | return out -------------------------------------------------------------------------------- /model/multi_frame_worker.py: -------------------------------------------------------------------------------- 1 | # DepthInSpace is a PyTorch-based program which estimates 3D depth maps 2 | # from active structured-light sensor's multiple video frames. 
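# In broad strokes (summary added for readability, see the code below): this worker trains the
# multi-frame FuseNet. net_forward feeds the IR frames, ambient images, the single-frame
# (primary) disparity and its depth, the camera poses and the precomputed optical flow into the
# network; loss_forward then combines, approximately,
#     total ≈ photometric pattern-similarity loss
#             + 0.8 * disparity smoothness loss
#             + (0.2 / n_frame_pairs) * sum of pairwise flow-consistency (geometric) losses
#             + 0.1 * warm-up terms (toward the primary disparity in the first two epochs, and
#                                    toward the SGM disparity on real data during warmup_epochs)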
3 | # 4 | # MIT License 5 | # 6 | # Copyright, 2021 ams International AG 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to the following conditions: 14 | # 15 | # The above copyright notice and this permission notice shall be included in all 16 | # copies or substantial portions of the Software. 17 | # 18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | # SOFTWARE. 25 | 26 | import torch 27 | import numpy as np 28 | import logging 29 | import itertools 30 | import matplotlib.pyplot as plt 31 | import co 32 | 33 | from data import base_dataset 34 | from data import dataset 35 | 36 | from model import networks 37 | 38 | from . import worker 39 | 40 | class Worker(worker.Worker): 41 | def __init__(self, args, **kwargs): 42 | super().__init__(args, **kwargs) 43 | 44 | self.disparity_loss = networks.DisparitySmoothLoss() 45 | 46 | def get_train_set(self): 47 | train_set = dataset.TrackSynDataset(self.settings_path, self.train_paths, train=True, data_aug=True, track_length=self.track_length, load_flow_data = True, load_primary_data = True, load_pseudo_gt = False, data_type = self.data_type) 48 | return train_set 49 | 50 | def get_test_sets(self): 51 | test_sets = base_dataset.TestSets() 52 | test_set = dataset.TrackSynDataset(self.settings_path, self.test_paths, train=False, data_aug=False, track_length=self.track_length, load_flow_data = True, load_primary_data = True, load_pseudo_gt = self.use_pseudo_gt, data_type = self.data_type) 53 | test_sets.append('simple', test_set, test_frequency=1) 54 | 55 | self.patterns = [] 56 | self.ph_losses = [] 57 | self.ge_losses = [] 58 | self.d2ds = [] 59 | 60 | self.lcn_in = self.lcn_in.to('cuda') 61 | for sidx in range(len(test_set.imsizes)): 62 | imsize = test_set.imsizes[sidx] 63 | pat = test_set.patterns[sidx] 64 | pat = pat.mean(axis=2) 65 | pat = torch.from_numpy(pat[None][None].astype(np.float32)).to('cuda') 66 | pat,_ = self.lcn_in(pat) 67 | 68 | self.patterns.append(pat) 69 | 70 | pat = torch.cat([pat for idx in range(3)], dim=1) 71 | ph_loss = networks.RectifiedPatternSimilarityLoss(imsize[0],imsize[1], pattern=pat) 72 | 73 | K = test_set.getK(sidx) 74 | Ki = np.linalg.inv(K) 75 | K = torch.from_numpy(K) 76 | Ki = torch.from_numpy(Ki) 77 | ge_loss = networks.Multi_Frame_Flow_Consistency_Loss(K, Ki, imsize[0], imsize[1], clamp=0.1) 78 | 79 | self.ph_losses.append( ph_loss ) 80 | self.ge_losses.append( ge_loss ) 81 | 82 | d2d = networks.DispToDepth(float(test_set.focal_lengths[sidx]), float(test_set.baseline)) 83 | self.d2ds.append( d2d ) 84 | 85 | return test_sets 86 | 87 | def net_forward(self, net, flow): 88 | im0 = self.data['im0'] 89 | ambient0 = self.data['ambient0'] 90 | disp0 = self.data['primary_disp'] 91 | 
R = self.data['R'] 92 | t = self.data['t'] 93 | 94 | d2d = self.d2ds[0] 95 | depth = d2d(disp0) 96 | 97 | self.primary_disp = disp0[0, 0, 0, ...].detach().cpu().numpy() 98 | 99 | out = net(im0, ambient0, disp0, depth, R, t, flow) 100 | 101 | return out 102 | 103 | def loss_forward(self, out, train, flow_out = None): 104 | if not(isinstance(out, tuple) or isinstance(out, list)): 105 | out = [out] 106 | 107 | vals = [] 108 | 109 | # apply photometric loss 110 | for s,o in zip(itertools.count(), out): 111 | im = self.data[f'im0'] 112 | im = im.view(-1, *im.shape[2:]) 113 | o = o.view(-1, *o.shape[2:]) 114 | std = self.data[f'std0'] 115 | std = std.view(-1, *std.shape[2:]) 116 | val, pattern_proj = self.ph_losses[0](o, im[:,0:1,...], std) 117 | vals.append(val / (2 ** s)) 118 | 119 | # apply disparity loss 120 | for s, o in zip(itertools.count(), out): 121 | if s == 0: 122 | amb0 = self.data[f'ambient0'] 123 | amb0 = amb0.contiguous().view(-1, *amb0.shape[2:]) 124 | o = o.view(-1, *o.shape[2:]) 125 | val = self.disparity_loss(o, amb0) 126 | vals.append(val * 0.8 / (2 ** s)) 127 | 128 | # apply geometric loss 129 | self.flow_mask = None 130 | R = self.data['R'] 131 | t = self.data['t'] 132 | primary_disp = self.data['primary_disp'] 133 | amb = self.data['ambient0'] 134 | 135 | ge_num = self.track_length * (self.track_length-1) / 2 136 | for sidx in range(1): 137 | d2d = self.d2ds[0] 138 | depth = d2d(out[sidx]) 139 | primary_depth = d2d(primary_disp) 140 | ge_loss = self.ge_losses[0] 141 | for tidx0 in range(depth.shape[0]): 142 | for tidx1 in range(tidx0+1, depth.shape[0]): 143 | depth0 = depth[tidx0] 144 | R0 = R[tidx0] 145 | t0 = t[tidx0] 146 | amb0 = amb[tidx0] 147 | primary_depth0 = primary_depth[tidx0] 148 | flow0 = flow_out[f'flow_{tidx0}{tidx1}'] 149 | depth1 = depth[tidx1] 150 | R1 = R[tidx1] 151 | t1 = t[tidx1] 152 | amb1 = amb[tidx1] 153 | primary_depth1 = primary_depth[tidx1] 154 | flow1 = flow_out[f'flow_{tidx1}{tidx0}'] 155 | 156 | val = ge_loss(depth0, depth1, R0, t0, R1, t1, flow0, flow1, amb0, amb1, primary_depth0, primary_depth1) 157 | vals.append(val * 0.2 / ge_num / (2 ** sidx)) 158 | 159 | # warming up the network for a few epochs 160 | if train: 161 | if self.current_epoch < 2: 162 | for s, o in zip(itertools.count(), out): 163 | if s == 0: 164 | val = torch.mean(torch.abs(o - self.data['primary_disp'])) 165 | vals.append(val * 0.1) 166 | 167 | # warming up the network for a few epochs 168 | if self.current_epoch < self.warmup_epochs and self.data_type == 'real': 169 | for s, o in zip(itertools.count(), out): 170 | if s == 0: 171 | valid_mask = (self.data['sgm_disp'] > 30).float() 172 | val = torch.sum(torch.abs(o - self.data['sgm_disp'] + 1.5 * torch.randn(o.size()).cuda()) * valid_mask) / torch.sum(valid_mask) 173 | vals.append(val * 0.1) 174 | 175 | return vals 176 | 177 | def numpy_in_out(self, output): 178 | if not(isinstance(output, tuple) or isinstance(output, list)): 179 | output = [output] 180 | es = output[0].detach().to('cpu').numpy() 181 | gt = self.data['disp0'].detach().to('cpu').numpy().astype(np.float32) 182 | im = self.data['im0'][:,:,0:1,...].detach().to('cpu').numpy() 183 | amb = self.data['ambient0'].detach().to('cpu').numpy() 184 | pat = self.patterns[0].detach().to('cpu').numpy() 185 | 186 | es = es * (gt > 0) 187 | 188 | return es, gt, im, amb, pat 189 | 190 | def write_img(self, out_path, es, gt, im, amb, pat): 191 | logging.info(f'write img {out_path}') 192 | 193 | diff = np.abs(es - gt) 194 | 195 | vmin, vmax = np.nanmin(gt), np.nanmax(gt) 196 | 
vmin = vmin - 0.2*(vmax-vmin) 197 | vmax = vmax + 0.2*(vmax-vmin) 198 | 199 | vmax = np.max([vmax, 16]) 200 | 201 | fig = plt.figure(figsize=(16,16)) 202 | # plot pattern and input images 203 | ax = plt.subplot(3,3,1); plt.imshow(pat, vmin=pat.min(), vmax=pat.max(), cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'Projector Pattern') 204 | ax = plt.subplot(3,3,2); plt.imshow(im[0], vmin=im.min(), vmax=im.max(), cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 IR Input') 205 | ax = plt.subplot(3,3,3); plt.imshow(amb[0], vmin=amb.min(), vmax=amb.max(), cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 Ambient Input') 206 | 207 | # plot disparities, ground truth disparity is shown only for reference 208 | es0 = co.cmap.color_depth_map(es[0], scale=vmax) 209 | gt0 = co.cmap.color_depth_map(gt[0], scale=vmax) 210 | diff0 = co.cmap.color_error_image(diff[0], BGR=True) 211 | 212 | ax = plt.subplot(3,3,4); plt.imshow(gt0[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 Disparity GT {np.nanmin(gt[0]):.4f}/{np.nanmax(gt[0]):.4f}') 213 | ax = plt.subplot(3,3,5); plt.imshow(es0[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 Disparity Est. {es[0].min():.4f}/{es[0].max():.4f}') 214 | ax = plt.subplot(3,3,6); plt.imshow(diff0[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 Disparity Err. {diff[0].mean():.5f}') 215 | 216 | es1 = co.cmap.color_depth_map(self.primary_disp, scale=vmax) 217 | gt1 = co.cmap.color_depth_map(gt[0], scale=vmax) 218 | diff1_np = np.abs(self.primary_disp - gt[0]) 219 | diff1 = co.cmap.color_error_image(diff1_np, BGR=True) 220 | ax = plt.subplot(3,3,7); plt.imshow(gt1[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 Disparity GT {np.nanmin(gt[0]):.4f}/{np.nanmax(gt[0]):.4f}') 221 | ax = plt.subplot(3,3,8); plt.imshow(es1[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 Disparity Input. {self.primary_disp.min():.4f}/{self.primary_disp.max():.4f}') 222 | ax = plt.subplot(3,3,9); plt.imshow(diff1[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 Disparity Input Err. 
{diff1_np.mean():.5f}') 223 | 224 | plt.tight_layout() 225 | plt.savefig(str(out_path)) 226 | plt.close(fig) 227 | 228 | def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks): 229 | if batch_idx % 256 == 0: 230 | out_path = self.exp_output_dir / f'train_{epoch:03d}_{batch_idx:04d}.png' 231 | es, gt, im, amb, pat = self.numpy_in_out(output) 232 | self.write_img(out_path, es[:, 0, 0], gt[:, 0, 0], im[:, 0, 0], amb[:, 0, 0], pat[0, 0]) 233 | torch.cuda.empty_cache() 234 | 235 | def callback_test_start(self, epoch, set_idx): 236 | self.metric = co.metric.MultipleMetric( 237 | co.metric.DistanceMetric(vec_length=1), 238 | co.metric.OutlierFractionMetric(vec_length=1, thresholds=[0.1, 0.5, 1, 2, 5]) 239 | ) 240 | 241 | def callback_test_add(self, epoch, set_idx, batch_idx, n_batches, output, masks): 242 | es, gt, im, amb, pat = self.numpy_in_out(output) 243 | 244 | if batch_idx % 8 == 0: 245 | out_path = self.exp_output_dir / f'test_{epoch:03d}_{batch_idx:04d}.png' 246 | self.write_img(out_path, es[:, 0, 0], gt[:, 0, 0], im[:, 0, 0], amb[:, 0, 0], pat[0, 0]) 247 | 248 | es = self.crop_reshape(es) 249 | gt = self.crop_reshape(gt) 250 | 251 | self.metric.add(es, gt) 252 | 253 | def crop_reshape(self, input): 254 | output = input.reshape(-1, 1) 255 | return output 256 | 257 | def callback_test_stop(self, epoch, set_idx, loss): 258 | logging.info(f'{self.metric}') 259 | for k, v in self.metric.items(): 260 | self.metric_add_test(epoch, set_idx, k, v) 261 | 262 | if __name__ == '__main__': 263 | pass 264 | -------------------------------------------------------------------------------- /model/networks.py: -------------------------------------------------------------------------------- 1 | # DepthInSpace is a PyTorch-based program which estimates 3D depth maps 2 | # from active structured-light sensor's multiple video frames. 3 | # 4 | # 5 | # MIT License 6 | # 7 | # Copyright (c) 2019 autonomousvision 8 | # 9 | # Permission is hereby granted, free of charge, to any person obtaining a copy 10 | # of this software and associated documentation files (the "Software"), to deal 11 | # in the Software without restriction, including without limitation the rights 12 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 13 | # copies of the Software, and to permit persons to whom the Software is 14 | # furnished to do so, subject to the following conditions: 15 | # 16 | # The above copyright notice and this permission notice shall be included in all 17 | # copies or substantial portions of the Software. 18 | # 19 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 20 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 21 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 22 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 23 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 24 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 25 | # SOFTWARE. 
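# Illustrative note (not original code): the disparity heads defined further below squash raw
# network activations into a bounded range through SigmoidAffine,
#     out = sigmoid(x / gamma - offset) * alpha + beta
# With the parameters used for disparity outputs (alpha = max_disp, beta = 0, gamma = 1,
# offset = 3), predictions always lie in (0, max_disp); for the default max_disp of 128, a
# zero activation maps to sigmoid(-3) * 128 ≈ 0.047 * 128 ≈ 6.1 pixels, so untrained
# (near-zero) activations correspond to small disparities.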
26 | # 27 | # 28 | # MIT License 29 | # 30 | # Copyright, 2021 ams International AG 31 | # 32 | # Permission is hereby granted, free of charge, to any person obtaining a copy 33 | # of this software and associated documentation files (the "Software"), to deal 34 | # in the Software without restriction, including without limitation the rights 35 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 36 | # copies of the Software, and to permit persons to whom the Software is 37 | # furnished to do so, subject to the following conditions: 38 | # 39 | # The above copyright notice and this permission notice shall be included in all 40 | # copies or substantial portions of the Software. 41 | # 42 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 43 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 44 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 45 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 46 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 47 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 48 | # SOFTWARE. 49 | 50 | import torch 51 | import torch.nn.functional as F 52 | import numpy as np 53 | 54 | import co 55 | 56 | from . import ext_functions 57 | 58 | class TimedModule(torch.nn.Module): 59 | def __init__(self, mod_name): 60 | super().__init__() 61 | self.mod_name = mod_name 62 | 63 | def tforward(self, *args, **kwargs): 64 | raise Exception('not implemented') 65 | 66 | def forward(self, *args, **kwargs): 67 | torch.cuda.synchronize() 68 | with co.gtimer.Ctx(self.mod_name): 69 | x = self.tforward(*args, **kwargs) 70 | torch.cuda.synchronize() 71 | return x 72 | 73 | 74 | class PosOutput(TimedModule): 75 | def __init__(self, channels_in, type, im_height, im_width, alpha=1, beta=0, gamma=1, offset=0): 76 | super().__init__(mod_name='PosOutput') 77 | self.im_width = im_width 78 | self.im_width = im_width 79 | 80 | if type == 'pos': 81 | self.layer = torch.nn.Sequential( 82 | torch.nn.Conv2d(channels_in, 1, kernel_size=3, padding=1), 83 | SigmoidAffine(alpha=alpha, beta=beta, gamma=gamma, offset=offset) 84 | ) 85 | elif type == 'pos_row': 86 | self.layer = torch.nn.Sequential( 87 | MultiLinear(im_height, channels_in, 1), 88 | SigmoidAffine(alpha=alpha, beta=beta, gamma=gamma, offset=offset) 89 | ) 90 | 91 | self.u_pos = None 92 | 93 | def tforward(self, x): 94 | if self.u_pos is None: 95 | self.u_pos = torch.arange(x.shape[3], dtype=torch.float32).view(1,1,1,-1) 96 | self.u_pos = self.u_pos.to(x.device) 97 | pos = self.layer(x) 98 | disp = self.u_pos - pos 99 | return disp 100 | 101 | 102 | class OutputLayerFactory(object): 103 | ''' 104 | Define type of output 105 | type options: 106 | linear: apply only conv channel, used for the edge decoder 107 | disp: estimate the disparity 108 | disp_row: independently estimate the disparity per row 109 | pos: estimate the absolute location 110 | pos_row: independently estimate the absolute location per row 111 | ''' 112 | def __init__(self, type='disp', params={}): 113 | self.type = type 114 | self.params = params 115 | 116 | def __call__(self, channels_in, imsize = None): 117 | 118 | if self.type == 'linear': 119 | return torch.nn.Conv2d(channels_in, 1, kernel_size=3, padding=1) 120 | 121 | elif self.type == 'disp': 122 | return torch.nn.Sequential( 123 | torch.nn.Conv2d(channels_in, 1, kernel_size=3, padding=1), 124 | 
SigmoidAffine(**self.params) 125 | ) 126 | 127 | elif self.type == 'disp_row': 128 | return torch.nn.Sequential( 129 | MultiLinear(imsize[0], channels_in, 1), 130 | SigmoidAffine(**self.params) 131 | ) 132 | 133 | elif self.type == 'pos' or self.type == 'pos_row': 134 | return PosOutput(channels_in, **self.params) 135 | 136 | else: 137 | raise Exception('unknown output layer type') 138 | 139 | 140 | class SigmoidAffine(TimedModule): 141 | def __init__(self, alpha=1, beta=0, gamma=1, offset=0): 142 | super().__init__(mod_name='SigmoidAffine') 143 | self.alpha = alpha 144 | self.beta = beta 145 | self.gamma = gamma 146 | self.offset = offset 147 | 148 | def tforward(self, x): 149 | return torch.sigmoid(x/self.gamma - self.offset) * self.alpha + self.beta 150 | 151 | 152 | class MultiLinear(TimedModule): 153 | def __init__(self, n, channels_in, channels_out): 154 | super().__init__(mod_name='MultiLinear') 155 | self.channels_out = channels_out 156 | self.mods = torch.nn.ModuleList() 157 | for idx in range(n): 158 | self.mods.append(torch.nn.Linear(channels_in, channels_out)) 159 | 160 | def tforward(self, x): 161 | x = x.permute(2,0,3,1) # BxCxHxW => HxBxWxC 162 | y = x.new_empty(*x.shape[:-1], self.channels_out) 163 | for hidx in range(x.shape[0]): 164 | y[hidx] = self.mods[hidx](x[hidx]) 165 | y = y.permute(1,3,0,2) # HxBxWxC => BxCxHxW 166 | return y 167 | 168 | 169 | 170 | class DispNetS(TimedModule): 171 | ''' 172 | Disparity Decoder based on DispNetS 173 | ''' 174 | def __init__(self, channels_in, imsizes, output_facs, coordconv=False, weight_init=False, channel_multiplier=1): 175 | super(DispNetS, self).__init__(mod_name='DispNetS') 176 | 177 | conv_planes = channel_multiplier * np.array( [32, 64, 128, 256, 512, 512, 512] ) 178 | self.conv1 = self.downsample_conv(channels_in, conv_planes[0], kernel_size=7) 179 | self.conv2 = self.downsample_conv(conv_planes[0], conv_planes[1], kernel_size=5) 180 | self.conv3 = self.downsample_conv(conv_planes[1], conv_planes[2]) 181 | self.conv4 = self.downsample_conv(conv_planes[2], conv_planes[3]) 182 | self.conv5 = self.downsample_conv(conv_planes[3], conv_planes[4]) 183 | self.conv6 = self.downsample_conv(conv_planes[4], conv_planes[5]) 184 | self.conv7 = self.downsample_conv(conv_planes[5], conv_planes[6]) 185 | 186 | upconv_planes = channel_multiplier * np.array( [512, 512, 256, 128, 64, 32, 16] ) 187 | self.upconv7 = self.upconv(conv_planes[6], upconv_planes[0]) 188 | self.upconv6 = self.upconv(upconv_planes[0], upconv_planes[1]) 189 | self.upconv5 = self.upconv(upconv_planes[1], upconv_planes[2]) 190 | self.upconv4 = self.upconv(upconv_planes[2], upconv_planes[3]) 191 | self.upconv3 = self.upconv(upconv_planes[3], upconv_planes[4]) 192 | self.upconv2 = self.upconv(upconv_planes[4], upconv_planes[5]) 193 | self.upconv1 = self.upconv(upconv_planes[5], upconv_planes[6]) 194 | 195 | self.iconv7 = self.conv(upconv_planes[0] + conv_planes[5], upconv_planes[0]) 196 | self.iconv6 = self.conv(upconv_planes[1] + conv_planes[4], upconv_planes[1]) 197 | self.iconv5 = self.conv(upconv_planes[2] + conv_planes[3], upconv_planes[2]) 198 | self.iconv4 = self.conv(upconv_planes[3] + conv_planes[2], upconv_planes[3]) 199 | self.iconv3 = self.conv(1 + upconv_planes[4] + conv_planes[1], upconv_planes[4]) 200 | self.iconv2 = self.conv(1 + upconv_planes[5] + conv_planes[0], upconv_planes[5]) 201 | self.iconv1 = self.conv(1 + upconv_planes[6], upconv_planes[6]) 202 | 203 | 204 | if isinstance(output_facs, list): 205 | self.predict_disp4 = 
output_facs[3](upconv_planes[3], imsizes[3]) 206 | self.predict_disp3 = output_facs[2](upconv_planes[4], imsizes[2]) 207 | self.predict_disp2 = output_facs[1](upconv_planes[5], imsizes[1]) 208 | self.predict_disp1 = output_facs[0](upconv_planes[6], imsizes[0]) 209 | else: 210 | self.predict_disp4 = output_facs(upconv_planes[3], imsizes[3]) 211 | self.predict_disp3 = output_facs(upconv_planes[4], imsizes[2]) 212 | self.predict_disp2 = output_facs(upconv_planes[5], imsizes[1]) 213 | self.predict_disp1 = output_facs(upconv_planes[6], imsizes[0]) 214 | 215 | # def init_weights(self): 216 | # for m in self.modules(): 217 | # if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.ConvTranspose2d): 218 | # torch.nn.init.xavier_uniform_(m.weight, gain=0.1) 219 | # if m.bias is not None: 220 | # torch.nn.init.zeros_(m.bias) 221 | 222 | def downsample_conv(self, in_planes, out_planes, kernel_size=3): 223 | return torch.nn.Sequential( 224 | torch.nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=2, padding=(kernel_size-1)//2), 225 | torch.nn.ReLU(inplace=True), 226 | torch.nn.Conv2d(out_planes, out_planes, kernel_size=kernel_size, padding=(kernel_size-1)//2), 227 | torch.nn.ReLU(inplace=True) 228 | ) 229 | 230 | def conv(self, in_planes, out_planes): 231 | return torch.nn.Sequential( 232 | torch.nn.Conv2d(in_planes, out_planes, kernel_size=3, padding=1), 233 | torch.nn.ReLU(inplace=True) 234 | ) 235 | 236 | def upconv(self, in_planes, out_planes): 237 | return torch.nn.Sequential( 238 | torch.nn.ConvTranspose2d(in_planes, out_planes, kernel_size=3, stride=2, padding=1, output_padding=1), 239 | torch.nn.ReLU(inplace=True) 240 | ) 241 | 242 | def crop_like(self, input, ref): 243 | assert(input.size(2) >= ref.size(2) and input.size(3) >= ref.size(3)) 244 | return input[:, :, :ref.size(2), :ref.size(3)] 245 | 246 | def tforward(self, x): 247 | out_conv1 = self.conv1(x) 248 | out_conv2 = self.conv2(out_conv1) 249 | out_conv3 = self.conv3(out_conv2) 250 | out_conv4 = self.conv4(out_conv3) 251 | out_conv5 = self.conv5(out_conv4) 252 | out_conv6 = self.conv6(out_conv5) 253 | out_conv7 = self.conv7(out_conv6) 254 | 255 | out_upconv7 = self.crop_like(self.upconv7(out_conv7), out_conv6) 256 | concat7 = torch.cat((out_upconv7, out_conv6), 1) 257 | out_iconv7 = self.iconv7(concat7) 258 | 259 | out_upconv6 = self.crop_like(self.upconv6(out_iconv7), out_conv5) 260 | concat6 = torch.cat((out_upconv6, out_conv5), 1) 261 | out_iconv6 = self.iconv6(concat6) 262 | 263 | out_upconv5 = self.crop_like(self.upconv5(out_iconv6), out_conv4) 264 | concat5 = torch.cat((out_upconv5, out_conv4), 1) 265 | out_iconv5 = self.iconv5(concat5) 266 | 267 | out_upconv4 = self.crop_like(self.upconv4(out_iconv5), out_conv3) 268 | concat4 = torch.cat((out_upconv4, out_conv3), 1) 269 | out_iconv4 = self.iconv4(concat4) 270 | disp4 = self.predict_disp4(out_iconv4) 271 | 272 | out_upconv3 = self.crop_like(self.upconv3(out_iconv4), out_conv2) 273 | disp4_up = self.crop_like(torch.nn.functional.interpolate(disp4, scale_factor=2, mode='bilinear', align_corners=False), out_conv2) 274 | concat3 = torch.cat((out_upconv3, out_conv2, disp4_up), 1) 275 | out_iconv3 = self.iconv3(concat3) 276 | disp3 = self.predict_disp3(out_iconv3) 277 | 278 | out_upconv2 = self.crop_like(self.upconv2(out_iconv3), out_conv1) 279 | disp3_up = self.crop_like(torch.nn.functional.interpolate(disp3, scale_factor=2, mode='bilinear', align_corners=False), out_conv1) 280 | concat2 = torch.cat((out_upconv2, out_conv1, disp3_up), 1) 281 | out_iconv2 = 
self.iconv2(concat2) 282 | disp2 = self.predict_disp2(out_iconv2) 283 | 284 | out_upconv1 = self.crop_like(self.upconv1(out_iconv2), x) 285 | disp2_up = self.crop_like(torch.nn.functional.interpolate(disp2, scale_factor=2, mode='bilinear', align_corners=False), x) 286 | concat1 = torch.cat((out_upconv1, disp2_up), 1) 287 | out_iconv1 = self.iconv1(concat1) 288 | disp1 = self.predict_disp1(out_iconv1) 289 | 290 | out1 = disp1 291 | out2 = torch.nn.functional.interpolate(input=disp2, size=(out1.size(2), out1.size(3)), mode='bilinear', align_corners=False) 292 | out3 = torch.nn.functional.interpolate(input=disp3, size=(out1.size(2), out1.size(3)), mode='bilinear', align_corners=False) 293 | out4 = torch.nn.functional.interpolate(input=disp4, size=(out1.size(2), out1.size(3)), mode='bilinear', align_corners=False) 294 | 295 | return (out1, out2, out3, out4) 296 | 297 | class DispDecoder(TimedModule): 298 | ''' 299 | Disparity Decoder 300 | ''' 301 | def __init__(self, *args, max_disp=128, **kwargs): 302 | super(DispDecoder, self).__init__(mod_name='DispDecoder') 303 | 304 | output_facs_disp = [OutputLayerFactory( type='disp', params={ 'alpha': max_disp/(2**s), 'beta': 0, 'gamma': 1, 'offset': 3}) for s in range(4)] 305 | self.disp_decoder = DispNetS(*args, output_facs=output_facs_disp, **kwargs) 306 | 307 | def tforward(self, x): 308 | disp = self.disp_decoder(x) 309 | return disp 310 | 311 | class DispToDepth(TimedModule): 312 | def __init__(self, focal_length, baseline): 313 | super().__init__(mod_name='DispToDepth') 314 | self.baseline_focal_length = baseline * focal_length 315 | 316 | def tforward(self, disp): 317 | disp = torch.nn.functional.relu(disp) + 1e-12 318 | depth = self.baseline_focal_length / disp 319 | return depth 320 | 321 | class PosToDepth(DispToDepth): 322 | def __init__(self, focal_length, baseline, im_height, im_width): 323 | super().__init__(focal_length, baseline) 324 | self.mod_name = 'PosToDepth' 325 | 326 | self.im_height = im_height 327 | self.im_width = im_width 328 | self.u_pos = torch.arange(im_width, dtype=torch.float32).view(1,1,1,-1) 329 | 330 | def tforward(self, pos): 331 | self.u_pos = self.u_pos.to(pos.device) 332 | disp = self.u_pos - pos 333 | return super().forward(disp) 334 | 335 | 336 | class RectifiedPatternSimilarityLoss(TimedModule): 337 | ''' 338 | Photometric Loss 339 | ''' 340 | def __init__(self, im_height, im_width, pattern, loss_type='census_sad', loss_eps=0.5): 341 | super().__init__(mod_name='RectifiedPatternSimilarityLoss') 342 | self.im_height = im_height 343 | self.im_width = im_width 344 | self.pattern = pattern.mean(dim=1, keepdim=True).contiguous() 345 | 346 | u, v = np.meshgrid(range(im_width), range(im_height)) 347 | uv0 = np.stack((u,v), axis=2).reshape(-1,1) 348 | uv0 = uv0.astype(np.float32).reshape(1,-1,2) 349 | self.uv0 = torch.from_numpy(uv0) 350 | 351 | self.loss_type = loss_type 352 | self.loss_eps = loss_eps 353 | 354 | def tforward(self, disp0, im, std=None, output_mean=True): 355 | self.pattern = self.pattern.to(disp0.device) 356 | self.uv0 = self.uv0.to(disp0.device) 357 | 358 | uv0 = self.uv0.expand(disp0.shape[0], *self.uv0.shape[1:]) 359 | uv1 = torch.empty_like(uv0) 360 | uv1[...,0] = uv0[...,0] - disp0.contiguous().view(disp0.shape[0],-1) 361 | uv1[...,1] = uv0[...,1] 362 | 363 | uv1[..., 0] = 2 * (uv1[..., 0] / (self.im_width-1) - 0.5) 364 | uv1[..., 1] = 2 * (uv1[..., 1] / (self.im_height-1) - 0.5) 365 | uv1 = uv1.view(-1, self.im_height, self.im_width, 2).clone() 366 | pattern = 
self.pattern.expand(disp0.shape[0], *self.pattern.shape[1:]) 367 | pattern_proj = torch.nn.functional.grid_sample(pattern, uv1, padding_mode='border', align_corners=True) 368 | mask = torch.ones_like(im) 369 | if std is not None: 370 | mask = mask*std 371 | 372 | diff = ext_functions.photometric_loss(pattern_proj.contiguous(), im.contiguous(), 9, self.loss_type, self.loss_eps) 373 | if output_mean: 374 | val = (mask*diff).sum() / mask.sum() 375 | else: 376 | val = diff 377 | return val, pattern_proj 378 | 379 | class SSIM(TimedModule): 380 | """Layer to compute the SSIM loss between a pair of images 381 | """ 382 | def __init__(self): 383 | super().__init__(mod_name='SSIM') 384 | self.mu_x_pool = torch.nn.AvgPool2d(3, 1) 385 | self.mu_y_pool = torch.nn.AvgPool2d(3, 1) 386 | self.sig_x_pool = torch.nn.AvgPool2d(3, 1) 387 | self.sig_y_pool = torch.nn.AvgPool2d(3, 1) 388 | self.sig_xy_pool = torch.nn.AvgPool2d(3, 1) 389 | 390 | self.refl = torch.nn.ReflectionPad2d(1) 391 | 392 | self.C1 = 0.01 ** 2 393 | self.C2 = 0.03 ** 2 394 | 395 | def tforward(self, x, y): 396 | x = self.refl(x) 397 | y = self.refl(y) 398 | 399 | mu_x = self.mu_x_pool(x) 400 | mu_y = self.mu_y_pool(y) 401 | 402 | sigma_x = self.sig_x_pool(x ** 2) - mu_x ** 2 403 | sigma_y = self.sig_y_pool(y ** 2) - mu_y ** 2 404 | sigma_xy = self.sig_xy_pool(x * y) - mu_x * mu_y 405 | 406 | SSIM_n = (2 * mu_x * mu_y + self.C1) * (2 * sigma_xy + self.C2) 407 | SSIM_d = (mu_x ** 2 + mu_y ** 2 + self.C1) * (sigma_x + sigma_y + self.C2) 408 | 409 | return torch.clamp((1 - SSIM_n / SSIM_d) / 2, 0, 1) 410 | 411 | class DisparitySmoothLoss(TimedModule): 412 | ''' 413 | Depth Smooth Loss 414 | ''' 415 | def __init__(self): 416 | super().__init__(mod_name='DepthSmoothLoss') 417 | self.sobel = SobelFilter(norm=False, ksize=5) 418 | 419 | def tforward(self, disp, im): 420 | self.sobel=self.sobel.to(disp.device) 421 | 422 | mean_disp = disp.mean(2, True).mean(3, True) 423 | # norm_disp = disp / (mean_disp + 1e-7) 424 | norm_disp = disp 425 | 426 | grad = self.sobel(norm_disp) 427 | grad_im = self.sobel(im) 428 | 429 | val = torch.abs(grad * torch.exp(-torch.abs(255 * grad_im))) 430 | 431 | return val.mean() 432 | 433 | class ProjectionBaseLoss(TimedModule): 434 | ''' 435 | Base module of the Geometric Loss 436 | ''' 437 | def __init__(self, K, Ki, im_height, im_width): 438 | super().__init__(mod_name='ProjectionBaseLoss') 439 | 440 | self.K = K.view(-1,3,3) 441 | 442 | self.im_height = im_height 443 | self.im_width = im_width 444 | 445 | u, v = np.meshgrid(range(im_width), range(im_height)) 446 | uv = np.stack((u,v,np.ones_like(u)), axis=2).reshape(-1,3) 447 | 448 | ray = uv @ Ki.numpy().T 449 | 450 | ray = ray.reshape(1,-1,3).astype(np.float32) 451 | self.ray = torch.from_numpy(ray) 452 | self.u = torch.from_numpy(u.astype('float32')) 453 | self.v = torch.from_numpy(v.astype('float32')) 454 | 455 | def transform(self, xyz, R=None, t=None): 456 | if t is not None: 457 | bs = xyz.shape[0] 458 | xyz = xyz - t.reshape(bs,1,3) 459 | if R is not None: 460 | xyz = torch.bmm(xyz, R) 461 | return xyz 462 | 463 | def unproject(self, depth, R=None, t=None): 464 | self.ray = self.ray.to(depth.device) 465 | bs = depth.shape[0] 466 | 467 | xyz = depth.reshape(bs,-1,1) * self.ray 468 | xyz = self.transform(xyz, R, t) 469 | return xyz 470 | 471 | def project(self, xyz, R, t, return_ray_format= False): 472 | self.K = self.K.to(xyz.device) 473 | bs = xyz.shape[0] 474 | 475 | xyz = torch.bmm(xyz, R.transpose(1,2)) 476 | xyz = xyz + t.reshape(bs,1,3) 477 | 478 | if 
return_ray_format: 479 | uv = xyz 480 | else: 481 | Kt = self.K.transpose(1,2).expand(bs,-1,-1) 482 | uv = torch.bmm(xyz, Kt) 483 | 484 | d = uv[:,:,2:3] 485 | 486 | # avoid division by zero 487 | uv = uv[:,:,:2] / (torch.nn.functional.relu(d) + 1e-12) 488 | return uv, d 489 | 490 | 491 | def tforward(self, depth0, R0, t0, R1, t1, return_ray_format= False): 492 | xyz = self.unproject(depth0, R0, t0) 493 | return self.project(xyz, R1, t1, return_ray_format) 494 | 495 | 496 | class ProjectionDepthSimilarityLoss(ProjectionBaseLoss): 497 | ''' 498 | Geometric Loss 499 | ''' 500 | def __init__(self, *args, clamp=-1): 501 | super().__init__(*args) 502 | self.mod_name = 'ProjectionDepthSimilarityLoss' 503 | self.clamp = clamp 504 | 505 | def fwd(self, depth0, depth1, R0, t0, R1, t1): 506 | self.u = self.u.to(depth0.device) 507 | self.v = self.v.to(depth0.device) 508 | 509 | uv1, d1 = super().tforward(depth0, R0, t0, R1, t1) 510 | uv1 = uv1.view(-1, self.im_height, self.im_width, 2) 511 | d1 = d1.view(-1, 1, self.im_height, self.im_width) 512 | 513 | # Calculation rigid flow 514 | rigid_flow = uv1.clone() 515 | rigid_flow[..., 0] -= self.u 516 | rigid_flow[..., 1] -= self.v 517 | rigid_flow = rigid_flow.permute(0, 3, 1, 2) 518 | 519 | uv1[..., 0] = 2 * (uv1[..., 0] / (self.im_width-1) - 0.5) 520 | uv1[..., 1] = 2 * (uv1[..., 1] / (self.im_height-1) - 0.5) 521 | depth10 = torch.nn.functional.grid_sample(depth1, uv1, padding_mode='border', align_corners=True) 522 | 523 | diff = torch.abs(d1 - depth10) 524 | 525 | if self.clamp > 0: 526 | diff = torch.clamp(diff, 0, self.clamp) 527 | 528 | orig_mask = (diff.detach() < self.clamp).to('cpu').numpy().astype('float32') 529 | 530 | return diff.mean(), rigid_flow, orig_mask[0][0] 531 | 532 | def tforward(self, depth0, depth1, R0, t0, R1, t1): 533 | l0, rigid_flow0, orig_mask = self.fwd(depth0, depth1, R0, t0, R1, t1) 534 | l1, rigid_flow1, _ = self.fwd(depth1, depth0, R1, t1, R0, t0) 535 | 536 | with torch.no_grad(): 537 | mask0 = self.generate_mask(rigid_flow0.detach(), rigid_flow1.detach()) 538 | mask1 = self.generate_mask(rigid_flow1.detach(), rigid_flow0.detach()) 539 | 540 | return l0+l1, rigid_flow0, rigid_flow1, mask0, mask1, orig_mask 541 | 542 | def generate_mask(self, flow0, flow1): 543 | uv1 = flow0.clone().permute(0, 2, 3, 1) 544 | uv1[..., 0] += self.u 545 | uv1[..., 1] += self.v 546 | uv1[..., 0] = 2 * (uv1[..., 0] / (self.im_width-1) - 0.5) 547 | uv1[..., 1] = 2 * (uv1[..., 1] / (self.im_height-1) - 0.5) 548 | flow0_proj = torch.nn.functional.grid_sample(flow1, uv1, padding_mode='border', align_corners=True) 549 | mask0 = ((flow0 + flow0_proj) ** 2).sum(dim=1) < 0.25 + 0.02 * ((flow0 ** 2).sum(dim=1) + (flow0_proj ** 2).sum(dim=1)) 550 | mask0 = mask0.type(torch.float32).unsqueeze(1) 551 | return mask0 552 | 553 | 554 | class Multi_Frame_Flow_Consistency_Loss(ProjectionBaseLoss): 555 | ''' 556 | Flow Consistency Loss for Multi-Frame Inference 557 | ''' 558 | 559 | def __init__(self, *args, clamp=-1): 560 | super().__init__(*args) 561 | self.mod_name = 'Multi_Frame_Flow_Consistency_Loss' 562 | self.clamp = clamp 563 | 564 | def fwd(self, depth0, depth1, R0, t0, R1, t1, flow0, flow1, amb0, amb1, primary_depth1): 565 | self.u = self.u.to(depth0.device) 566 | self.v = self.v.to(depth0.device) 567 | 568 | uv1, d1 = super().tforward(depth0, R0, t0, R1, t1) 569 | uv1 = uv1.view(-1, self.im_height, self.im_width, 2) 570 | d1 = d1.view(-1, 1, self.im_height, self.im_width) 571 | 572 | uv1_flow = flow0.permute(0, 2, 3, 1).clone() 573 | uv1_flow[..., 
0] += self.u 574 | uv1_flow[..., 1] += self.v 575 | 576 | uv1_flow[..., 0] = 2 * (uv1_flow[..., 0] / (self.im_width - 1) - 0.5) 577 | uv1_flow[..., 1] = 2 * (uv1_flow[..., 1] / (self.im_height - 1) - 0.5) 578 | depth10 = torch.nn.functional.grid_sample(depth1, uv1_flow, padding_mode='zeros', align_corners=True) 579 | 580 | diff = torch.abs(d1 - depth10) 581 | 582 | with torch.no_grad(): 583 | flow10 = torch.nn.functional.grid_sample(flow1.detach(), uv1_flow.detach(), padding_mode='zeros', align_corners=True) 584 | fb_mask = ((flow0.detach() + flow10) ** 2).sum(dim=1) < 0.5 + 0.02 * ( 585 | (flow0.detach() ** 2).sum(dim=1) + (flow10 ** 2).sum(dim=1)) 586 | fb_mask = fb_mask.type(torch.float32).unsqueeze(1) 587 | 588 | amb10 = torch.nn.functional.grid_sample(amb1.detach(), uv1_flow.detach(), padding_mode='zeros', align_corners=True) 589 | vc_mask = (((amb0 - amb10).abs()).mean(dim=1, keepdim=True) < 0.01).type(torch.float32) 590 | 591 | uv0, d0 = super().tforward(primary_depth1.detach(), R1.detach(), t1.detach(), R0.detach(), t0.detach(), return_ray_format= False) 592 | uv0 = uv0.view(-1, self.im_height, self.im_width, 2).permute(0, 3, 1, 2) 593 | warped_uv0 = torch.nn.functional.grid_sample(uv0.detach(), uv1_flow.detach(), padding_mode='zeros', align_corners=True) 594 | self_uv = torch.stack([self.u, self.v], dim= 0).unsqueeze(0) 595 | rf_mask = (((warped_uv0 - self_uv) ** 2).sum(dim = 1, keepdim=True) < 1).type(torch.float32) 596 | 597 | loss_mask = fb_mask * vc_mask * rf_mask 598 | 599 | diff = (diff * loss_mask).sum() / (loss_mask.sum() + 1e-8) 600 | 601 | return diff 602 | 603 | def tforward(self, depth0, depth1, R0, t0, R1, t1, flow0, flow1, amb0, amb1, primary_depth0, primary_depth1): 604 | l0 = self.fwd(depth0, depth1, R0, t0, R1, t1, flow0, flow1, amb0, amb1, primary_depth1) 605 | l1 = self.fwd(depth1, depth0, R1, t1, R0, t0, flow1, flow0, amb1, amb0, primary_depth0) 606 | 607 | return l0 + l1 608 | 609 | class Single_Frame_Flow_Consistency_Loss(ProjectionBaseLoss): 610 | ''' 611 | Flow Consistency Loss for Single Frame Inference 612 | ''' 613 | 614 | def __init__(self, *args, clamp=-1): 615 | super().__init__(*args) 616 | self.mod_name = 'Single_Frame_Flow_Consistency_Loss' 617 | self.clamp = clamp 618 | 619 | def fwd(self, depth0, depth1, R0, t0, R1, t1, flow0, flow1, amb0, amb1): 620 | self.u = self.u.to(depth0.device) 621 | self.v = self.v.to(depth0.device) 622 | 623 | uv1, d1 = super().tforward(depth0, R0, t0, R1, t1) 624 | uv1 = uv1.view(-1, self.im_height, self.im_width, 2) 625 | d1 = d1.view(-1, 1, self.im_height, self.im_width) 626 | 627 | uv1_flow = flow0.permute(0, 2, 3, 1).clone() 628 | uv1_flow[..., 0] += self.u 629 | uv1_flow[..., 1] += self.v 630 | 631 | uv1_flow[..., 0] = 2 * (uv1_flow[..., 0] / (self.im_width - 1) - 0.5) 632 | uv1_flow[..., 1] = 2 * (uv1_flow[..., 1] / (self.im_height - 1) - 0.5) 633 | depth10 = torch.nn.functional.grid_sample(depth1, uv1_flow, padding_mode='zeros', align_corners=True) 634 | 635 | diff = torch.abs(d1 - depth10) 636 | 637 | if self.clamp > 0: 638 | diff = torch.clamp(diff, 0, self.clamp) 639 | 640 | orig_mask = (diff.detach() < self.clamp).to('cpu').numpy().astype('float32') 641 | 642 | with torch.no_grad(): 643 | flow10 = torch.nn.functional.grid_sample(flow1.detach(), uv1_flow.detach(), padding_mode='zeros', align_corners=True) 644 | fb_mask = ((flow0.detach() + flow10) ** 2).sum(dim=1) < 0.5 + 0.02 * ( 645 | (flow0.detach() ** 2).sum(dim=1) + (flow10 ** 2).sum(dim=1)) 646 | fb_mask = fb_mask.type(torch.float32).unsqueeze(1) 647 | 
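# fb_mask: forward-backward flow consistency check. The backward flow flow1 is
# warped into the reference frame with the forward flow; a pixel is kept only
# if the two roughly cancel, i.e. |flow0 + flow10|^2 < 0.5 + 0.02 * (|flow0|^2 + |flow10|^2),
# which suppresses occluded pixels and gross flow errors. The vc_mask computed
# next additionally requires the warped ambient image to match the reference
# ambient image, so the depth difference is only penalised where both checks pass.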
648 | amb10 = torch.nn.functional.grid_sample(amb1.detach(), uv1_flow.detach(), padding_mode='zeros', align_corners=True) 649 | vc_mask = (((amb0 - amb10).abs()).mean(dim=1, keepdim=True) < 0.01).type(torch.float32) 650 | 651 | loss_mask = fb_mask * vc_mask 652 | 653 | diff = (diff * loss_mask).sum() / (loss_mask.sum() + 1e-8) 654 | 655 | return diff, loss_mask, orig_mask[0][0] 656 | 657 | def tforward(self, depth0, depth1, R0, t0, R1, t1, flow0, flow1, amb0, amb1): 658 | l0, mask0, orig_mask = self.fwd(depth0, depth1, R0, t0, R1, t1, flow0, flow1, amb0, amb1) 659 | l1, mask1, _ = self.fwd(depth1, depth0, R1, t1, R0, t0, flow1, flow0, amb1, amb0) 660 | 661 | return l0 + l1, mask0, mask1, orig_mask 662 | 663 | class LCN(TimedModule): 664 | ''' 665 | Local Contract Normalization 666 | ''' 667 | def __init__(self, radius, epsilon): 668 | super().__init__(mod_name='LCN') 669 | self.box_conv = torch.nn.Sequential( 670 | torch.nn.ReflectionPad2d(radius), 671 | torch.nn.Conv2d(1, 1, kernel_size=2*radius+1, bias=False) 672 | ) 673 | self.box_conv[1].weight.requires_grad=False 674 | self.box_conv[1].weight.fill_(1.) 675 | 676 | self.epsilon = epsilon 677 | self.radius = radius 678 | 679 | def tforward(self, data): 680 | boxs = self.box_conv(data) 681 | 682 | avgs = boxs / (2*self.radius+1)**2 683 | boxs_n2 = boxs**2 684 | boxs_2n = self.box_conv(data**2) 685 | 686 | stds = torch.sqrt(torch.clamp(boxs_2n / (2*self.radius+1)**2 - avgs**2 + 1e-6, min=0)) 687 | stds = stds + self.epsilon 688 | 689 | return (data - avgs) / stds, stds 690 | 691 | 692 | 693 | class SobelFilter(TimedModule): 694 | ''' 695 | Sobel Filter 696 | ''' 697 | def __init__(self, norm=False, ksize = 5): 698 | super(SobelFilter, self).__init__(mod_name='SobelFilter') 699 | self.ksize = ksize 700 | if self.ksize == 5: 701 | kx = np.array([[-5, -4, 0, 4, 5], 702 | [-8, -10, 0, 10, 8], 703 | [-10, -20, 0, 20, 10], 704 | [-8, -10, 0, 10, 8], 705 | [-5, -4, 0, 4, 5]])/240.0 706 | elif self.ksize == 3: 707 | kx = np.array([[-1, 0, 1], 708 | [-2, 0, 2], 709 | [-1, 0, 1]])/8.0 710 | 711 | ky = kx.copy().transpose(1,0) 712 | 713 | self.conv_x=torch.nn.Conv2d(1, 1, kernel_size=self.ksize, stride=1, padding=0, bias=False) 714 | self.conv_x.weight=torch.nn.Parameter(torch.from_numpy(kx).float().unsqueeze(0).unsqueeze(0)) 715 | 716 | self.conv_y=torch.nn.Conv2d(1, 1, kernel_size=self.ksize, stride=1, padding=0, bias=False) 717 | self.conv_y.weight=torch.nn.Parameter(torch.from_numpy(ky).float().unsqueeze(0).unsqueeze(0)) 718 | 719 | self.norm=norm 720 | 721 | def tforward(self,x): 722 | if self.ksize == 5: 723 | x = F.pad(x, (2,2,2,2), "replicate") 724 | elif self.ksize == 3: 725 | x = F.pad(x, (1,1,1,1), "replicate") 726 | gx = self.conv_x(x) 727 | gy = self.conv_y(x) 728 | if self.norm: 729 | return torch.sqrt(gx**2 + gy**2 + 1e-8) 730 | else: 731 | return torch.cat((gx, gy), dim=1) -------------------------------------------------------------------------------- /model/single_frame_worker.py: -------------------------------------------------------------------------------- 1 | # DepthInSpace is a PyTorch-based program which estimates 3D depth maps 2 | # from active structured-light sensor's multiple video frames. 
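# This file implements the training / evaluation worker for the single-frame
# architecture: it builds the synthetic or real datasets, sets up the
# photometric, disparity-smoothness and flow-consistency losses, runs the
# DispDecoder network on every frame of a track independently, and writes out
# diagnostic images and metrics.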
3 | # 4 | # MIT License 5 | # 6 | # Copyright, 2021 ams International AG 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to the following conditions: 14 | # 15 | # The above copyright notice and this permission notice shall be included in all 16 | # copies or substantial portions of the Software. 17 | # 18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | # SOFTWARE. 25 | 26 | import torch 27 | import numpy as np 28 | import logging 29 | import itertools 30 | import matplotlib.pyplot as plt 31 | import co 32 | 33 | from data import base_dataset 34 | from data import dataset 35 | 36 | from model import networks 37 | 38 | from . import worker 39 | 40 | class Worker(worker.Worker): 41 | def __init__(self, args, **kwargs): 42 | super().__init__(args, **kwargs) 43 | 44 | self.disparity_loss = networks.DisparitySmoothLoss() 45 | 46 | def get_train_set(self): 47 | train_set = dataset.TrackSynDataset(self.settings_path, self.train_paths, train=True, data_aug=True, track_length=self.track_length, load_flow_data = True, load_primary_data = False, load_pseudo_gt = self.use_pseudo_gt, data_type = self.data_type) 48 | return train_set 49 | 50 | def get_test_sets(self): 51 | test_sets = base_dataset.TestSets() 52 | test_set = dataset.TrackSynDataset(self.settings_path, self.test_paths, train=False, data_aug=False, track_length=self.track_length, load_flow_data = True, load_primary_data = False, load_pseudo_gt = self.use_pseudo_gt, data_type = self.data_type) 53 | test_sets.append('simple', test_set, test_frequency=1) 54 | 55 | self.patterns = [] 56 | self.ph_losses = [] 57 | self.ge_losses = [] 58 | self.d2ds = [] 59 | 60 | self.lcn_in = self.lcn_in.to('cuda') 61 | for sidx in range(len(test_set.imsizes)): 62 | imsize = test_set.imsizes[sidx] 63 | pat = test_set.patterns[sidx] 64 | pat = pat.mean(axis=2) 65 | pat = torch.from_numpy(pat[None][None].astype(np.float32)).to('cuda') 66 | pat,_ = self.lcn_in(pat) 67 | 68 | self.patterns.append(pat) 69 | 70 | pat = torch.cat([pat for idx in range(3)], dim=1) 71 | ph_loss = networks.RectifiedPatternSimilarityLoss(imsize[0],imsize[1], pattern=pat) 72 | 73 | K = test_set.getK(sidx) 74 | Ki = np.linalg.inv(K) 75 | K = torch.from_numpy(K) 76 | Ki = torch.from_numpy(Ki) 77 | ge_loss = networks.Single_Frame_Flow_Consistency_Loss(K, Ki, imsize[0], imsize[1], clamp=0.1) 78 | 79 | self.ph_losses.append( ph_loss ) 80 | self.ge_losses.append( ge_loss ) 81 | 82 | d2d = networks.DispToDepth(float(test_set.focal_lengths[sidx]), float(test_set.baseline)) 83 | self.d2ds.append( d2d ) 84 | 85 | return test_sets 86 | 87 | def net_forward(self, net, flow = None): 88 | im0 = self.data['im0'] 89 | tl = im0.shape[0] 90 | bs = im0.shape[1] 91 | im0 = 
im0.view(-1, *im0.shape[2:]) 92 | out = net(im0) 93 | 94 | if not(isinstance(out, tuple) or isinstance(out, list)): 95 | out = out.view(tl, bs, *out.shape[1:]) 96 | else: 97 | out = [o.view(tl, bs, *o.shape[1:]) for o in out] 98 | 99 | return out 100 | 101 | def loss_forward(self, out, train, flow_out = None): 102 | if not(isinstance(out, tuple) or isinstance(out, list)): 103 | out = [out] 104 | 105 | vals = [] 106 | 107 | # apply photometric loss 108 | for s,o in zip(itertools.count(), out): 109 | im = self.data[f'im0'] 110 | im = im.view(-1, *im.shape[2:]) 111 | o = o.view(-1, *o.shape[2:]) 112 | std = self.data[f'std0'] 113 | std = std.view(-1, *std.shape[2:]) 114 | val, pattern_proj = self.ph_losses[0](o, im[:,0:1,...], std) 115 | vals.append(val / (2 ** s)) 116 | 117 | # apply disparity loss 118 | for s, o in zip(itertools.count(), out): 119 | if s == 0: 120 | amb0 = self.data[f'ambient0'] 121 | amb0 = amb0.contiguous().view(-1, *amb0.shape[2:]) 122 | o = o.view(-1, *o.shape[2:]) 123 | val = self.disparity_loss(o, amb0) 124 | vals.append(val * 0.4 / (2 ** s)) 125 | 126 | # apply geometric loss 127 | R = self.data['R'] 128 | t = self.data['t'] 129 | amb = self.data['ambient0'] 130 | ge_num = self.track_length * (self.track_length-1) / 2 131 | for sidx in range(1): 132 | d2d = self.d2ds[0] 133 | depth = d2d(out[sidx]) 134 | ge_loss = self.ge_losses[0] 135 | for tidx0 in range(depth.shape[0]): 136 | for tidx1 in range(tidx0+1, depth.shape[0]): 137 | depth0 = depth[tidx0] 138 | R0 = R[tidx0] 139 | t0 = t[tidx0] 140 | amb0 = amb[tidx0] 141 | flow0 = flow_out[f'flow_{tidx0}{tidx1}'] 142 | depth1 = depth[tidx1] 143 | R1 = R[tidx1] 144 | t1 = t[tidx1] 145 | amb1 = amb[tidx1] 146 | flow1 = flow_out[f'flow_{tidx1}{tidx0}'] 147 | 148 | val, flow_mask0, flow_mask1, orig_mask = ge_loss(depth0, depth1, R0, t0, R1, t1, flow0, flow1, amb0, amb1) 149 | vals.append(val * 0.2 / ge_num / (2 ** sidx)) 150 | 151 | # using pseudo-ground truth 152 | if self.use_pseudo_gt: 153 | for s, o in zip(itertools.count(), out): 154 | val = torch.mean(torch.abs(o - self.data['pseudo_gt'])) 155 | vals.append(val * 0.1 / (2 ** s)) 156 | 157 | # warming up the network for a few epochs 158 | if train and self.data_type == 'real': 159 | if self.current_epoch < self.warmup_epochs: 160 | for s, o in zip(itertools.count(), out): 161 | valid_mask = (self.data['sgm_disp'] > 30).float() 162 | val = torch.sum(torch.abs(o - self.data['sgm_disp'] + 1.5 * torch.randn(o.size()).cuda()) * valid_mask) / torch.sum(valid_mask) 163 | vals.append(val * 0.1) 164 | 165 | return vals 166 | 167 | def numpy_in_out(self, output): 168 | if not(isinstance(output, tuple) or isinstance(output, list)): 169 | output = [output] 170 | es = output[0].detach().to('cpu').numpy() 171 | gt = self.data['disp0'].detach().to('cpu').numpy().astype(np.float32) 172 | im = self.data['im0'][:,:,0:1,...].detach().to('cpu').numpy() 173 | amb = self.data['ambient0'].detach().to('cpu').numpy() 174 | pat = self.patterns[0].detach().to('cpu').numpy() 175 | 176 | es = es * (gt > 0) 177 | 178 | return es, gt, im, amb, pat 179 | 180 | def write_img(self, out_path, es, gt, im, amb, pat): 181 | logging.info(f'write img {out_path}') 182 | 183 | diff = np.abs(es - gt) 184 | 185 | vmin, vmax = np.nanmin(gt), np.nanmax(gt) 186 | vmin = vmin - 0.2*(vmax-vmin) 187 | vmax = vmax + 0.2*(vmax-vmin) 188 | 189 | vmax = np.max([vmax, 16]) 190 | 191 | fig = plt.figure(figsize=(16,16)) 192 | es0 = co.cmap.color_depth_map(es[0], scale=vmax) 193 | gt0 = co.cmap.color_depth_map(gt[0], 
scale=vmax) 194 | diff0 = co.cmap.color_error_image(diff[0], BGR=True) 195 | 196 | # plot pattern and input images 197 | ax = plt.subplot(3,3,1); plt.imshow(pat, vmin=pat.min(), vmax=pat.max(), cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'Projector Pattern') 198 | ax = plt.subplot(3,3,2); plt.imshow(im[0], vmin=im.min(), vmax=im.max(), cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 IR Input') 199 | ax = plt.subplot(3,3,3); plt.imshow(amb[0], vmin=amb.min(), vmax=amb.max(), cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 Ambient Input') 200 | 201 | # plot disparities, ground truth disparity is shown only for reference 202 | ax = plt.subplot(3,3,4); plt.imshow(gt0[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 Disparity GT {np.nanmin(gt[0]):.4f}/{np.nanmax(gt[0]):.4f}') 203 | ax = plt.subplot(3,3,5); plt.imshow(es0[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 Disparity Est. {es[0].min():.4f}/{es[0].max():.4f}') 204 | ax = plt.subplot(3,3,6); plt.imshow(diff0[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 Disparity Err. {diff[0].mean():.5f}') 205 | 206 | es1 = co.cmap.color_depth_map(es[1], scale=vmax) 207 | gt1 = co.cmap.color_depth_map(gt[1], scale=vmax) 208 | diff1 = co.cmap.color_error_image(diff[1], BGR=True) 209 | ax = plt.subplot(3,3,7); plt.imshow(gt1[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F1 Disparity GT {np.nanmin(gt[1]):.4f}/{np.nanmax(gt[1]):.4f}') 210 | ax = plt.subplot(3,3,8); plt.imshow(es1[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F1 Disparity Est. {es[1].min():.4f}/{es[1].max():.4f}') 211 | ax = plt.subplot(3,3,9); plt.imshow(diff1[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F1 Disparity Err. 
{diff[1].mean():.5f}') 212 | 213 | plt.tight_layout() 214 | plt.savefig(str(out_path)) 215 | plt.close(fig) 216 | 217 | def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks): 218 | if batch_idx % 256 == 0: 219 | out_path = self.exp_output_dir / f'train_{epoch:03d}_{batch_idx:04d}.png' 220 | es, gt, im, amb, pat = self.numpy_in_out(output) 221 | self.write_img(out_path, es[:, 0, 0], gt[:, 0, 0], im[:, 0, 0], amb[:, 0, 0], pat[0, 0]) 222 | torch.cuda.empty_cache() 223 | 224 | def callback_test_start(self, epoch, set_idx): 225 | self.metric = co.metric.MultipleMetric( 226 | co.metric.DistanceMetric(vec_length=1), 227 | co.metric.OutlierFractionMetric(vec_length=1, thresholds=[0.1, 0.5, 1, 2, 5]) 228 | ) 229 | 230 | def callback_test_add(self, epoch, set_idx, batch_idx, n_batches, output, masks): 231 | es, gt, im, amb, pat = self.numpy_in_out(output) 232 | 233 | if batch_idx % 8 == 0: 234 | out_path = self.exp_output_dir / f'test_{epoch:03d}_{batch_idx:04d}.png' 235 | self.write_img(out_path, es[:, 0, 0], gt[:, 0, 0], im[:, 0, 0], amb[:, 0, 0], pat[0, 0]) 236 | 237 | es = self.crop_reshape(es) 238 | gt = self.crop_reshape(gt) 239 | self.metric.add(es, gt) 240 | 241 | def crop_reshape(self, input): 242 | output = input.reshape(-1, 1) 243 | return output 244 | 245 | def callback_test_stop(self, epoch, set_idx, loss): 246 | logging.info(f'{self.metric}') 247 | for k, v in self.metric.items(): 248 | self.metric_add_test(epoch, set_idx, k, v) 249 | 250 | if __name__ == '__main__': 251 | pass 252 | -------------------------------------------------------------------------------- /model/worker.py: -------------------------------------------------------------------------------- 1 | # DepthInSpace is a PyTorch-based program which estimates 3D depth maps 2 | # from active structured-light sensor's multiple video frames. 3 | # 4 | # 5 | # MIT License 6 | # 7 | # Copyright (c) 2019 autonomousvision 8 | # 9 | # Permission is hereby granted, free of charge, to any person obtaining a copy 10 | # of this software and associated documentation files (the "Software"), to deal 11 | # in the Software without restriction, including without limitation the rights 12 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 13 | # copies of the Software, and to permit persons to whom the Software is 14 | # furnished to do so, subject to the following conditions: 15 | # 16 | # The above copyright notice and this permission notice shall be included in all 17 | # copies or substantial portions of the Software. 18 | # 19 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 20 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 21 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 22 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 23 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 24 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 25 | # SOFTWARE. 
26 | # 27 | # 28 | # MIT License 29 | # 30 | # Copyright, 2021 ams International AG 31 | # 32 | # Permission is hereby granted, free of charge, to any person obtaining a copy 33 | # of this software and associated documentation files (the "Software"), to deal 34 | # in the Software without restriction, including without limitation the rights 35 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 36 | # copies of the Software, and to permit persons to whom the Software is 37 | # furnished to do so, subject to the following conditions: 38 | # 39 | # The above copyright notice and this permission notice shall be included in all 40 | # copies or substantial portions of the Software. 41 | # 42 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 43 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 44 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 45 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 46 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 47 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 48 | # SOFTWARE. 49 | 50 | import pickle 51 | 52 | import numpy as np 53 | import torch 54 | import os 55 | import random 56 | import logging 57 | import datetime 58 | from pathlib import Path 59 | import argparse 60 | import socket 61 | import gc 62 | import json 63 | import matplotlib.pyplot as plt 64 | import time 65 | from collections import OrderedDict 66 | from model import networks 67 | 68 | 69 | class StopWatch(object): 70 | def __init__(self): 71 | self.timings = OrderedDict() 72 | self.starts = {} 73 | 74 | def start(self, name): 75 | self.starts[name] = time.time() 76 | 77 | def stop(self, name): 78 | if name not in self.timings: 79 | self.timings[name] = [] 80 | self.timings[name].append(time.time() - self.starts[name]) 81 | 82 | def get(self, name=None, reduce=np.sum): 83 | if name is not None: 84 | return reduce(self.timings[name]) 85 | else: 86 | ret = {} 87 | for k in self.timings: 88 | ret[k] = reduce(self.timings[k]) 89 | return ret 90 | 91 | def __repr__(self): 92 | return ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get().items()]) 93 | def __str__(self): 94 | return ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get().items()]) 95 | 96 | 97 | class ETA(object): 98 | def __init__(self, length): 99 | self.length = length 100 | self.start_time = time.time() 101 | self.current_idx = 0 102 | self.current_time = time.time() 103 | 104 | def update(self, idx): 105 | self.current_idx = idx 106 | self.current_time = time.time() 107 | 108 | def get_elapsed_time(self): 109 | return self.current_time - self.start_time 110 | 111 | def get_item_time(self): 112 | return self.get_elapsed_time() / (self.current_idx + 1) 113 | 114 | def get_remaining_time(self): 115 | return self.get_item_time() * (self.length - self.current_idx + 1) 116 | 117 | def format_time(self, seconds): 118 | minutes, seconds = divmod(seconds, 60) 119 | hours, minutes = divmod(minutes, 60) 120 | hours = int(hours) 121 | minutes = int(minutes) 122 | return f'{hours:02d}:{minutes:02d}:{seconds:05.2f}' 123 | 124 | def get_elapsed_time_str(self): 125 | return self.format_time(self.get_elapsed_time()) 126 | 127 | def get_remaining_time_str(self): 128 | return self.format_time(self.get_remaining_time()) 129 | 130 | class Worker(object): 131 | def __init__(self, args, seed=42, test_batch_size=4, num_workers=4, save_frequency=1, 
train_device='cuda:0', test_device='cuda:0', max_train_iter=-1): 132 | self.use_pseudo_gt = args.use_pseudo_gt 133 | self.lcn_radius = args.lcn_radius 134 | self.track_length = args.track_length 135 | self.data_type = args.data_type 136 | # assert(self.track_length>1) 137 | 138 | self.architecture = args.architecture 139 | self.epochs = args.epochs 140 | self.warmup_epochs = args.warmup_epochs 141 | self.seed = seed 142 | self.train_batch_size = args.train_batch_size 143 | self.test_batch_size = test_batch_size 144 | self.num_workers = num_workers 145 | self.save_frequency = save_frequency 146 | self.train_device = train_device 147 | self.test_device = test_device 148 | self.max_train_iter = max_train_iter 149 | 150 | self.errs_list=[] 151 | 152 | config_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'config.json')) 153 | with open(config_path) as fp: 154 | config = json.load(fp) 155 | data_root = Path(config['DATA_DIR']) 156 | output_dir = Path(config['OUTPUT_DIR']) 157 | self.settings_path = data_root / 'settings.pkl' 158 | self.output_dir = output_dir 159 | with open(str(self.settings_path), 'rb') as f: 160 | settings = pickle.load(f) 161 | self.baseline = settings['baseline'] 162 | self.K = settings['K'] 163 | self.Ki = np.linalg.inv(self.K) 164 | self.imsizes = [settings['imsize']] 165 | for iter in range(3): 166 | self.imsizes.append((int(self.imsizes[-1][0] / 2), int(self.imsizes[-1][1] / 2))) 167 | self.ref_pattern = settings['pattern'] 168 | 169 | sample_paths = sorted((data_root).glob('0*/')) 170 | if self.data_type == 'synthetic': 171 | self.train_paths = sample_paths[2**10:] 172 | self.test_paths = sample_paths[2**9:2**10] 173 | self.valid_paths = sample_paths[0:2**9] 174 | elif self.data_type == 'real': 175 | self.test_paths = sample_paths[4::8] 176 | self.train_paths = [path for path in sample_paths if path not in self.test_paths] 177 | 178 | self.lcn_in = networks.LCN(self.lcn_radius, 0.05).cuda() 179 | 180 | self.setup_experiment() 181 | 182 | def setup_experiment(self): 183 | self.exp_output_dir = self.output_dir / self.architecture 184 | self.exp_output_dir.mkdir(parents=True, exist_ok=True) 185 | 186 | if logging.root: del logging.root.handlers[:] 187 | logging.basicConfig( 188 | level=logging.INFO, 189 | handlers=[ 190 | logging.FileHandler( str(self.exp_output_dir / 'train.log') ), 191 | logging.StreamHandler() 192 | ], 193 | format='%(relativeCreated)d:%(levelname)s:%(process)d-%(processName)s: %(message)s' 194 | ) 195 | 196 | logging.info('='*80) 197 | logging.info(f'Start of experiment with architecture: {self.architecture}') 198 | logging.info(socket.gethostname()) 199 | self.log_datetime() 200 | logging.info('='*80) 201 | 202 | self.metric_path = self.exp_output_dir / 'metrics.json' 203 | if self.metric_path.exists(): 204 | with open(str(self.metric_path), 'r') as fp: 205 | self.metric_data = json.load(fp) 206 | else: 207 | self.metric_data = {} 208 | 209 | self.init_seed() 210 | 211 | def metric_add_train(self, epoch, key, val): 212 | epoch = str(epoch) 213 | key = str(key) 214 | if epoch not in self.metric_data: 215 | self.metric_data[epoch] = {} 216 | if 'train' not in self.metric_data[epoch]: 217 | self.metric_data[epoch]['train'] = {} 218 | self.metric_data[epoch]['train'][key] = val 219 | 220 | def metric_add_test(self, epoch, set_idx, key, val): 221 | epoch = str(epoch) 222 | set_idx = str(set_idx) 223 | key = str(key) 224 | if epoch not in self.metric_data: 225 | self.metric_data[epoch] = {} 226 | if 'test' not in 
self.metric_data[epoch]: 227 | self.metric_data[epoch]['test'] = {} 228 | if set_idx not in self.metric_data[epoch]['test']: 229 | self.metric_data[epoch]['test'][set_idx] = {} 230 | self.metric_data[epoch]['test'][set_idx][key] = val 231 | 232 | def metric_save(self): 233 | with open(str(self.metric_path), 'w') as fp: 234 | json.dump(self.metric_data, fp, indent=2) 235 | 236 | def init_seed(self, seed=None): 237 | if seed is not None: 238 | self.seed = seed 239 | logging.info(f'Set seed to {self.seed}') 240 | np.random.seed(self.seed) 241 | random.seed(self.seed) 242 | torch.manual_seed(self.seed) 243 | torch.cuda.manual_seed(self.seed) 244 | 245 | def log_datetime(self): 246 | logging.info(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')) 247 | 248 | def mem_report(self): 249 | for obj in gc.get_objects(): 250 | if torch.is_tensor(obj): 251 | print(type(obj), obj.shape) 252 | 253 | def get_net_path(self, epoch, root=None): 254 | if root is None: 255 | root = self.exp_output_dir 256 | return root / f'net_{epoch:04d}.params' 257 | 258 | def get_do_parser_cmds(self): 259 | return ['retrain', 'resume', 'retest', 'test_init'] 260 | 261 | def get_do_parser(self): 262 | parser = argparse.ArgumentParser() 263 | parser.add_argument('--cmd', type=str, default='resume', choices=self.get_do_parser_cmds()) 264 | parser.add_argument('--epoch', type=int, default=-1) 265 | return parser 266 | 267 | def do_cmd(self, args, net, optimizer, scheduler=None): 268 | if args.cmd == 'retrain': 269 | self.train(net, optimizer, resume=False, scheduler=scheduler) 270 | elif args.cmd == 'resume': 271 | self.train(net, optimizer, resume=True, scheduler=scheduler) 272 | elif args.cmd == 'retest': 273 | self.retest(net, epoch=args.epoch) 274 | elif args.cmd == 'test_init': 275 | test_sets = self.get_test_sets() 276 | self.test(-1, net, test_sets) 277 | else: 278 | raise Exception('invalid cmd') 279 | 280 | def do(self, net, optimizer, load_net_optimizer=None, scheduler=None): 281 | parser = self.get_do_parser() 282 | args, _ = parser.parse_known_args() 283 | 284 | if load_net_optimizer is not None and args.cmd not in ['schedule']: 285 | net, optimizer = load_net_optimizer() 286 | 287 | self.do_cmd(args, net, optimizer, scheduler=scheduler) 288 | 289 | def retest(self, net, epoch=-1): 290 | if epoch < 0: 291 | epochs = range(self.epochs) 292 | else: 293 | epochs = [epoch] 294 | 295 | test_sets = self.get_test_sets() 296 | 297 | for epoch in epochs: 298 | net_path = self.get_net_path(epoch) 299 | if net_path.exists(): 300 | state_dict = torch.load(str(net_path)) 301 | net.load_state_dict(state_dict) 302 | self.test(epoch, net, test_sets) 303 | 304 | def format_err_str(self, errs, div=1): 305 | err = sum(errs) 306 | if len(errs) > 1: 307 | err_str = f'{err/div:0.4f}=' + '+'.join([f'{e/div:0.4f}' for e in errs]) 308 | else: 309 | err_str = f'{err/div:0.4f}' 310 | return err_str 311 | 312 | def write_err_img(self): 313 | err_img_path = self.exp_output_dir / 'errs.png' 314 | fig = plt.figure(figsize=(16,16)) 315 | lines=[] 316 | for idx,errs in enumerate(self.errs_list): 317 | line,=plt.plot(range(len(errs)), errs, label=f'error{idx}') 318 | lines.append(line) 319 | plt.tight_layout() 320 | plt.legend(handles=lines) 321 | plt.savefig(str(err_img_path)) 322 | plt.close(fig) 323 | 324 | 325 | def callback_train_new_epoch(self, epoch, net, optimizer): 326 | pass 327 | 328 | def train(self, net, optimizer, resume=False, scheduler=None): 329 | logging.info('='*80) 330 | logging.info('Start training') 331 | 
self.log_datetime() 332 | logging.info('='*80) 333 | 334 | train_set = self.get_train_set() 335 | test_sets = self.get_test_sets() 336 | 337 | net = net.to(self.train_device) 338 | 339 | epoch = 0 340 | min_err = {ts.name: 1e9 for ts in test_sets} 341 | 342 | state_path = self.exp_output_dir / 'state.dict' 343 | if resume and state_path.exists(): 344 | logging.info('='*80) 345 | logging.info(f'Loading state from {state_path}') 346 | logging.info('='*80) 347 | state = torch.load(str(state_path)) 348 | epoch = state['epoch'] + 1 349 | if 'min_err' in state: 350 | min_err = state['min_err'] 351 | 352 | curr_state = net.state_dict() 353 | curr_state.update(state['state_dict']) 354 | net.load_state_dict(curr_state) 355 | 356 | try: 357 | optimizer.load_state_dict(state['optimizer']) 358 | except: 359 | logging.info('Warning: cannot load optimizer from state_dict') 360 | pass 361 | if 'cpu_rng_state' in state: 362 | torch.set_rng_state(state['cpu_rng_state']) 363 | if 'gpu_rng_state' in state: 364 | torch.cuda.set_rng_state(state['gpu_rng_state']) 365 | 366 | for epoch in range(epoch, self.epochs): 367 | self.current_epoch = epoch 368 | self.callback_train_new_epoch(epoch, net, optimizer) 369 | 370 | # train epoch 371 | self.train_epoch(epoch, net, optimizer, train_set) 372 | 373 | # test epoch 374 | errs = self.test(epoch, net, test_sets) 375 | 376 | if (epoch + 1) % self.save_frequency == 0: 377 | net = net.to(self.train_device) 378 | 379 | state_dict = { 380 | 'epoch': epoch, 381 | 'min_err': min_err, 382 | 'state_dict': net.state_dict(), 383 | 'optimizer': optimizer.state_dict(), 384 | 'cpu_rng_state': torch.get_rng_state(), 385 | 'gpu_rng_state': torch.cuda.get_rng_state(), 386 | } 387 | logging.info(f'save state to {state_path}') 388 | state_path = self.exp_output_dir / 'state.dict' 389 | torch.save(state_dict, str(state_path)) 390 | 391 | for test_set_name in errs: 392 | err = sum(errs[test_set_name]) 393 | if err < min_err[test_set_name]: 394 | min_err[test_set_name] = err 395 | state_path = self.exp_output_dir / f'state_set_{test_set_name}_best.dict' 396 | logging.info(f'save state to {state_path}') 397 | torch.save(state_dict, str(state_path)) 398 | 399 | # store network 400 | net_path = self.get_net_path(epoch) 401 | logging.info(f'save network to {net_path}') 402 | torch.save(net.state_dict(), str(net_path)) 403 | 404 | if scheduler is not None: 405 | scheduler.step() 406 | 407 | logging.info('='*80) 408 | logging.info('Finished training') 409 | self.log_datetime() 410 | logging.info('='*80) 411 | 412 | def get_train_set(self): 413 | raise NotImplementedError() 414 | 415 | def get_test_sets(self): 416 | raise NotImplementedError() 417 | 418 | def copy_data(self, data, device, requires_grad, train): 419 | self.data = {} 420 | 421 | self.lcn_in = self.lcn_in.to(device) 422 | for key, val in data.items(): 423 | # from 424 | # batch_size x track_length x ... 425 | # to 426 | # track_length x batch_size x ... 
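# In addition, IR images ('im*') are LCN-normalised and concatenated with the
# raw image along the channel dimension, and the per-pixel LCN std map is
# stored under the matching 'std*' key (it later weights the photometric
# loss); the ambient image receives the same treatment under 'ambient0_in'.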
427 | if len(val.shape)>2: 428 | val = val.transpose(0, 1) 429 | self.data[key] = val.to(device) 430 | if 'im' in key and 'blend' not in key and 'primary' not in key: 431 | im = self.data[key] 432 | tl = im.shape[0] 433 | bs = im.shape[1] 434 | im_lcn,im_std = self.lcn_in(im.contiguous().view(-1, *im.shape[2:])) 435 | key_std = key.replace('im','std') 436 | self.data[key_std] = im_std.view(tl, bs, *im.shape[2:]).to(device) 437 | im_cat = torch.cat((im_lcn.view(tl, bs, *im.shape[2:]), im), dim=2) 438 | self.data[key] = im_cat 439 | elif key == 'ambient0': 440 | ambient = self.data[key] 441 | tl = ambient.shape[0] 442 | bs = ambient.shape[1] 443 | ambient_lcn, ambient_std = self.lcn_in(ambient.contiguous().view(-1, *ambient.shape[2:])) 444 | ambient_cat = torch.cat((ambient_lcn.view(tl, bs, *ambient.shape[2:]), ambient), dim=2) 445 | self.data[f'{key}_in'] = ambient_cat.to(device).requires_grad_(requires_grad=requires_grad) 446 | 447 | # Mimicing the reference pattern 448 | pat = self.ref_pattern.mean(axis=2) 449 | pat = torch.from_numpy(pat[None][None].astype(np.float32)).to('cuda') 450 | pat_lcn, _ = self.lcn_in(pat) 451 | pat_cat = torch.cat((pat_lcn, pat), dim=1).unsqueeze(0) 452 | self.data[f'ref_pattern'] = pat_cat.repeat([*self.data['im0'].shape[0:2], 1, 1, 1]) 453 | 454 | def net_forward(self, net, train): 455 | raise NotImplementedError() 456 | 457 | def read_optical_flow(self, train): 458 | im = self.data['ambient0'] 459 | out = {} 460 | for tidx0 in range(im.shape[0]): 461 | for tidx1 in range(im.shape[0]): 462 | if tidx0 != tidx1: 463 | out[f'flow_{tidx0}{tidx1}'] = self.data[f'flow_{tidx0}{tidx1}'][0] 464 | 465 | return out 466 | 467 | def loss_forward(self, output, train, flow_out): 468 | raise NotImplementedError() 469 | 470 | def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks): 471 | pass 472 | 473 | def callback_train_start(self, epoch): 474 | pass 475 | 476 | def callback_train_stop(self, epoch, loss): 477 | pass 478 | 479 | def train_epoch(self, epoch, net, optimizer, dset): 480 | self.callback_train_start(epoch) 481 | stopwatch = StopWatch() 482 | 483 | logging.info('='*80) 484 | logging.info('Train epoch %d' % epoch) 485 | 486 | dset.current_epoch = epoch 487 | train_loader = torch.utils.data.DataLoader(dset, batch_size=self.train_batch_size, shuffle=True, num_workers=self.num_workers, drop_last=True, pin_memory=False) 488 | 489 | net = net.to(self.train_device) 490 | net.train() 491 | 492 | mean_loss = None 493 | 494 | n_batches = self.max_train_iter if self.max_train_iter > 0 else len(train_loader) 495 | bar = ETA(length=n_batches) 496 | 497 | stopwatch.start('total') 498 | stopwatch.start('data') 499 | for batch_idx, data in enumerate(train_loader): 500 | if self.max_train_iter > 0 and batch_idx > self.max_train_iter: break 501 | self.copy_data(data, device=self.train_device, requires_grad=False, train=True) 502 | stopwatch.stop('data') 503 | 504 | optimizer.zero_grad() 505 | 506 | stopwatch.start('forward') 507 | flow_output = self.read_optical_flow(train=True) 508 | output = self.net_forward(net, flow_output) 509 | 510 | if 'cuda' in self.train_device: torch.cuda.synchronize() 511 | stopwatch.stop('forward') 512 | 513 | stopwatch.start('loss') 514 | errs = self.loss_forward(output, True, flow_output) 515 | if isinstance(errs, dict): 516 | masks = errs['masks'] 517 | errs = errs['errs'] 518 | else: 519 | masks = [] 520 | if not isinstance(errs, list) and not isinstance(errs, tuple): 521 | errs = [errs] 522 | err = sum(errs) 523 | if 
'cuda' in self.train_device: torch.cuda.synchronize() 524 | stopwatch.stop('loss') 525 | 526 | stopwatch.start('backward') 527 | err.backward() 528 | self.callback_train_post_backward(net, errs, output, epoch, batch_idx, masks) 529 | if 'cuda' in self.train_device: torch.cuda.synchronize() 530 | stopwatch.stop('backward') 531 | 532 | # print('Allocated:', round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1), 'GB') 533 | # print('Max Allocated:', round(torch.cuda.max_memory_allocated(0) / 1024 ** 3, 1), 'GB') 534 | # print('Max Cached:', round(torch.cuda.max_memory_cached(0) / 1024 ** 3, 1), 'GB') 535 | 536 | stopwatch.start('optimizer') 537 | optimizer.step() 538 | if 'cuda' in self.train_device: torch.cuda.synchronize() 539 | stopwatch.stop('optimizer') 540 | 541 | bar.update(batch_idx) 542 | if (epoch <= 1 and batch_idx < 128) or batch_idx % 16 == 0: 543 | err_str = self.format_err_str(errs) 544 | logging.info(f'train e{epoch}: {batch_idx+1}/{len(train_loader)}: loss={err_str} | {bar.get_elapsed_time_str()} / {bar.get_remaining_time_str()}') 545 | #self.write_err_img() 546 | 547 | 548 | if mean_loss is None: 549 | mean_loss = [0 for e in errs] 550 | for erridx, err in enumerate(errs): 551 | mean_loss[erridx] += err.item() 552 | 553 | stopwatch.start('data') 554 | stopwatch.stop('total') 555 | logging.info('timings: %s' % stopwatch) 556 | 557 | mean_loss = [l / len(train_loader) for l in mean_loss] 558 | self.callback_train_stop(epoch, mean_loss) 559 | self.metric_add_train(epoch, 'loss', mean_loss) 560 | 561 | # save metrics 562 | self.metric_save() 563 | 564 | err_str = self.format_err_str(mean_loss) 565 | logging.info(f'avg train_loss={err_str}') 566 | return mean_loss 567 | 568 | def callback_test_start(self, epoch, set_idx): 569 | pass 570 | 571 | def callback_test_add(self, epoch, set_idx, batch_idx, n_batches, output, masks): 572 | pass 573 | 574 | def callback_test_stop(self, epoch, set_idx, loss): 575 | pass 576 | 577 | def test(self, epoch, net, test_sets): 578 | errs = {} 579 | for test_set_idx, test_set in enumerate(test_sets): 580 | if (epoch + 1) % test_set.test_frequency == 0: 581 | logging.info('='*80) 582 | logging.info(f'testing set {test_set.name}') 583 | err = self.test_epoch(epoch, test_set_idx, net, test_set.dset) 584 | errs[test_set.name] = err 585 | return errs 586 | 587 | def test_epoch(self, epoch, set_idx, net, dset): 588 | logging.info('-'*80) 589 | logging.info('Test epoch %d' % epoch) 590 | dset.current_epoch = epoch 591 | test_loader = torch.utils.data.DataLoader(dset, batch_size=self.test_batch_size, shuffle=False, num_workers=self.num_workers, drop_last=False, pin_memory=False) 592 | 593 | net = net.to(self.test_device) 594 | net.eval() 595 | 596 | with torch.no_grad(): 597 | mean_loss = None 598 | 599 | self.callback_test_start(epoch, set_idx) 600 | 601 | bar = ETA(length=len(test_loader)) 602 | stopwatch = StopWatch() 603 | stopwatch.start('total') 604 | stopwatch.start('data') 605 | for batch_idx, data in enumerate(test_loader): 606 | self.copy_data(data, device=self.test_device, requires_grad=False, train=False) 607 | stopwatch.stop('data') 608 | 609 | stopwatch.start('forward') 610 | flow_output = self.read_optical_flow(train=False) 611 | 612 | output = self.net_forward(net, flow_output) 613 | 614 | if 'cuda' in self.test_device: torch.cuda.synchronize() 615 | stopwatch.stop('forward') 616 | 617 | stopwatch.start('loss') 618 | errs = self.loss_forward(output, False, flow_output) 619 | if isinstance(errs, dict): 620 | masks = errs['masks'] 621 | errs = 
errs['errs'] 622 | else: 623 | masks = [] 624 | if not isinstance(errs, list) and not isinstance(errs, tuple): 625 | errs = [errs] 626 | 627 | bar.update(batch_idx) 628 | if batch_idx % 25 == 0: 629 | err_str = self.format_err_str(errs) 630 | logging.info(f'test e{epoch}: {batch_idx+1}/{len(test_loader)}: loss={err_str} | {bar.get_elapsed_time_str()} / {bar.get_remaining_time_str()}') 631 | 632 | if mean_loss is None: 633 | mean_loss = [0 for e in errs] 634 | for erridx, err in enumerate(errs): 635 | mean_loss[erridx] += err.item() 636 | stopwatch.stop('loss') 637 | 638 | self.callback_test_add(epoch, set_idx, batch_idx, len(test_loader), output, masks) 639 | 640 | stopwatch.start('data') 641 | stopwatch.stop('total') 642 | logging.info('timings: %s' % stopwatch) 643 | 644 | mean_loss = [l / len(test_loader) for l in mean_loss] 645 | self.callback_test_stop(epoch, set_idx, mean_loss) 646 | self.metric_add_test(epoch, set_idx, 'loss', mean_loss) 647 | 648 | # save metrics 649 | self.metric_save() 650 | 651 | err_str = self.format_err_str(mean_loss) 652 | logging.info(f'test epoch {epoch}: avg test_loss={err_str}') 653 | return mean_loss 654 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pytorch 2 | torchvision 3 | cython 4 | numpy 5 | matplotlib 6 | pandas 7 | scipy 8 | opencv 9 | h5py 10 | tqdm 11 | cupy 12 | ipdb -------------------------------------------------------------------------------- /train_val.py: -------------------------------------------------------------------------------- 1 | # DepthInSpace is a PyTorch-based program which estimates 3D depth maps 2 | # from active structured-light sensor's multiple video frames. 3 | # 4 | # MIT License 5 | # 6 | # Copyright, 2021 ams International AG 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to the following conditions: 14 | # 15 | # The above copyright notice and this permission notice shall be included in all 16 | # copies or substantial portions of the Software. 17 | # 18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | # SOFTWARE. 
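# Entry point for training and validation: it instantiates the single-frame or
# multi-frame worker and the matching network according to args.architecture,
# builds an Adam optimizer (lr 1e-4) and hands control to worker.do(), whose
# --cmd option selects retrain / resume / retest / test_init (see model/worker.py).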
25 | 26 | import torch 27 | from model import single_frame_worker 28 | from model import multi_frame_worker 29 | from model import networks 30 | from model import multi_frame_networks 31 | from co.args import parse_args 32 | 33 | torch.backends.cudnn.benchmark = True 34 | 35 | # parse args 36 | args = parse_args() 37 | 38 | # loss types 39 | if args.architecture == 'single_frame': 40 | worker = single_frame_worker.Worker(args) 41 | elif args.architecture == 'multi_frame': 42 | worker = multi_frame_worker.Worker(args) 43 | 44 | if args.use_pseudo_gt and args.architecture != 'single_frame': 45 | print("Using pseudo-gt is only possible in single-frame architecture") 46 | raise NotImplementedError 47 | 48 | # set up network 49 | if args.architecture == 'single_frame': 50 | net = networks.DispDecoder(channels_in=2, max_disp=args.max_disp, imsizes=worker.imsizes) 51 | elif args.architecture == 'multi_frame': 52 | net = multi_frame_networks.FuseNet(imsize=worker.imsizes[0], K=worker.K, baseline=worker.baseline, track_length=worker.track_length, max_disp=args.max_disp) 53 | 54 | # optimizer 55 | opt_parameters = net.parameters() 56 | optimizer = torch.optim.Adam(opt_parameters, lr=1e-4) 57 | 58 | # start the work 59 | worker.do(net, optimizer) --------------------------------------------------------------------------------
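The geometric and flow-consistency losses in model/networks.py are all built on ProjectionBaseLoss, which back-projects a depth map to 3D through the inverse intrinsics and re-projects it into a second view. Below is a minimal sketch of that round trip with made-up intrinsics, an identity pose and a constant depth map (the real K, image size and poses come from settings.pkl and the dataset); with identical source and target poses every pixel should project back onto itself and the returned depth should equal the input depth.

    import torch
    from model import networks

    H, W = 8, 8
    K = torch.tensor([[500., 0., W / 2],
                      [0., 500., H / 2],
                      [0., 0., 1.]])
    proj = networks.ProjectionBaseLoss(K, torch.inverse(K), H, W)

    depth = torch.full((1, 1, H, W), 2.0)   # constant depth map, batch size 1
    R = torch.eye(3).unsqueeze(0)           # identity rotation
    t = torch.zeros(1, 3)                   # zero translation

    uv, d = proj(depth, R, t, R, t)         # project frame 0 into itself
    print(torch.allclose(d.view(1, 1, H, W), depth))   # True: depth is preserved
    print(uv.view(H, W, 2)[0, 0])                      # tensor([0., 0.]): pixel (0, 0) maps to itself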