├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── arguments.py ├── colors ├── __init__.py ├── cityscapes.py ├── plasma.py └── tango.py ├── dataloader ├── __init__.py ├── data_preprocessing │ ├── __init__.py │ ├── download_kitti.py │ ├── kitti_2015_generate_depth.py │ ├── kitti_archives_to_download.txt │ └── kitti_utils.py ├── definitions │ ├── __init__.py │ └── labels_file.py ├── eval │ ├── __init__.py │ └── metrics.py ├── file_io │ ├── __init__.py │ ├── dir_lister.py │ └── get_path.py └── pt_data_loader │ ├── __init__.py │ ├── basedataset.py │ ├── dataset_parameterset.py │ ├── mytransforms.py │ └── specialdatasets.py ├── dc_masking.py ├── environment.yml ├── eval_depth.py ├── eval_depth.sh ├── eval_pose.py ├── eval_pose.sh ├── eval_segmentation.py ├── eval_segmentation.sh ├── experiments ├── kitti_full.sh ├── kitti_only_depth.sh ├── zhou_full.sh └── zhou_only_depth.sh ├── harness.py ├── imgs ├── ECCV_presentation.jpg ├── intro.png ├── qualitative.png └── sg_depth.gif ├── inference.py ├── loaders ├── __init__.py ├── depth │ ├── __init__.py │ ├── train.py │ └── validation.py ├── fns.py ├── pose │ ├── __init__.py │ └── validation.py └── segmentation │ ├── __init__.py │ ├── train.py │ └── validation.py ├── losses ├── __init__.py ├── baselosses.py ├── depth.py └── segmentation.py ├── models ├── __init__.py ├── layers │ ├── __init__.py │ └── grad_scaling_layers.py ├── networks │ ├── __init__.py │ ├── multi_res_output.py │ ├── partial_decoder.py │ ├── pose_decoder.py │ └── resnet_encoder.py └── sgdepth.py ├── perspective_resample.py ├── requirements.txt ├── state_manager.py ├── timer.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | **/__pycache 3 | *~ 4 | slurm-*.out 5 | *.pth 6 | *.npy 7 | *_losses.json 8 | .idea 9 | test_*.sh 10 | train.sh 11 | 12 | train_proposed_local.sh 13 | eval_depth_local.sh 14 | eval_segmentation_local.sh 15 | eval_pose_local.sh 16 | results.txt 17 | backup 18 | backup/inference_2.py 19 | backup/inference.py 20 | inference_local.sh -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Marvin Klingner 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Self-Supervised Monocular Depth Estimation: Solving the Dynamic Object Problem by Semantic Guidance 2 | 3 | [Marvin Klingner](https://www.tu-braunschweig.de/en/ifn/institute/team/sv/klingner), [Jan-Aike Termöhlen](https://www.tu-braunschweig.de/en/ifn/institute/team/sv/termoehlen), Jonas Mikolajczyk, and [Tim Fingscheidt](https://www.tu-braunschweig.de/en/ifn/institute/team/sv/fingscheidt) – ECCV 2020 4 | 5 | 6 | [Link to paper](https://arxiv.org/abs/2007.06936) 7 | 8 | ## ECCV Presentation 9 |

10 | 11 | SGDepth video presentation ECCV 12 | 13 |

14 | 15 | ## Idea Behind the Method 16 | 17 | Self-supervised monocular depth estimation usually relies on the assumption of a static world during training, which is violated by dynamic objects. 18 | In our paper, we introduce a multi-task learning framework that semantically guides the self-supervised depth estimation to handle such objects. 19 | 20 |

21 | 22 |

23 | 24 | ## Citation 25 | 26 | If you find our work useful or interesting, please consider citing [our paper](https://arxiv.org/abs/2007.06936): 27 | 28 | ``` 29 | @inproceedings{klingner2020selfsupervised, 30 | title = {{Self-Supervised Monocular Depth Estimation: Solving the Dynamic Object Problem by Semantic Guidance}}, 31 | author = {Marvin Klingner and 32 | Jan-Aike Term\"{o}hlen and 33 | Jonas Mikolajczyk and 34 | Tim Fingscheidt 35 | }, 36 | booktitle = {{European Conference on Computer Vision ({ECCV})}}, 37 | year = {2020} 38 | } 39 | ``` 40 | 41 | ## Improved Depth Estimation Results 42 | 43 | As a consequence of the multi-task training, dynamic objects are more clearly shaped and small objects such as traffic signs or traffic lights are better recognised in comparison to previous methods. 44 | 45 |

46 | 47 |

48 | 49 | ## Our models 50 | 51 | | Model | Resolution | Abs Rel | Sq Rel | RMSE | RMSE log | δ < 1.25 | δ < 1.25^2 | δ < 1.25^3 | 52 | |:------------------:|:----------:|:-------:|:------:|:-----:|:--------:|----------|------------|------------| 53 | | [SGDepth only depth](https://drive.google.com/file/d/1KMabblEyvIEFKl4yEueiEKTa-LNGKqE2/view?usp=sharing) | 640x192 | 0.117 | 0.907 | 4.844 | 0.196 | 0.875 | 0.958 | 0.980 | 54 | | [SGDepth full](https://drive.google.com/file/d/1Hd5YTIhSMttAVWyE7a08r3GHlHDX1Z46/view?usp=sharing) | 640x192 | 0.113 | 0.835 | 4.693 | 0.191 | 0.879 | 0.961 | 0.981 | 55 | 56 | 57 | ### Inference Preview: 58 |

59 | 60 |

61 | 62 | ## Prerequisites and Requirements 63 | We recommend using Anaconda. An `environment.yml` file is provided. 64 | We use PyTorch 1.1 with CUDA 10.0. A `requirements.txt` file is also provided. Older and newer versions of the mentioned packages *may* also work. However, from PyTorch 1.3 on, the default behaviour of some functions (e.g. ``grid_sample()``) changed, which needs to be considered when training models with the newest PyTorch releases. 65 | To start working with our code, follow these steps: 66 | 67 | 1. In your project directory, export the environment variable for the checkpoints: ```export IFN_DIR_CHECKPOINT=/path/to/folder/Checkpoints``` 68 | 2. Download the Cityscapes dataset from *https://www.cityscapes-dataset.com/* and put it in a folder "Dataset". 69 | 3. Download the KITTI dataset from *http://www.cvlibs.net/datasets/kitti/* and place it in the same dataset folder. To ensure that you have the same folder structure as we have, you can directly use the script ``dataloader/data_preprocessing/download_kitti.py``. 70 | 4. If you want to evaluate on the KITTI 2015 stereo dataset, also download it from *http://www.cvlibs.net/datasets/kitti/* and run ``dataloader/data_preprocessing/kitti_2015_generate_depth.py`` to generate the depth maps. 71 | 5. Prepare the dataset folder: 72 | - Export an environment variable pointing to the root directory of all datasets: ```export IFN_DIR_DATASET=/path/to/folder/Dataset``` 73 | - Place the json files in your ```cityscapes``` folder. Please take care that the folder is spelled exactly as given here: ```"cityscapes"```. 74 | - Place the json files in your ```kitti``` folder. Please take care that the folder is spelled exactly as given here: ```"kitti"```. 75 | - Place the json files in your ```kitti_zhou_split``` folder. Please take care that the folder is spelled exactly as given here: ```"kitti_zhou_split"```. 76 | - Place the json files in your ```kitti_kitti_split``` folder. Please take care that the folder is spelled exactly as given here: ```"kitti_kitti_split"```. 77 | - Place the json files in your ```kitti_2015``` folder containing the KITTI 2015 Stereo dataset. Please take care that the folder is spelled exactly as given here: ```"kitti_2015"```. 78 | 79 | For further information, please also refer to our dataloader: Dataloader Repository 80 | 81 | ## Inference 82 | The inference script works independently; it only imports the model and the arguments. 83 | It runs inference on all images in a given directory and writes the results to a defined output directory. 84 | ``` 85 | python3 inference.py \ 86 | --model-path sgdepth_eccv_test/zhou_full/epoch_20/model.pth \ 87 | --inference-resize-height 192 \ 88 | --inference-resize-width 640 \ 89 | --image-path /path/to/input/dir \ 90 | --output-path /path/to/output/dir 91 | ``` 92 | You can also define the output format with `--output-format .png` or `.jpg`. 93 | 94 | 95 | ## Depth Evaluation 96 | For evaluation of the predicted depth, use `eval_depth.py`. 97 | Specify which model to use with the `--model-name` and `--model-load` flags. The path is relative to the exported checkpoint directory.
98 | An example is shown below: 99 | ``` 100 | python3 eval_depth.py \ 101 | --sys-best-effort-determinism \ 102 | --model-name "eval_kitti_depth" \ 103 | --model-load sgdepth_eccv_test/zhou_full/checkpoints/epoch_20 \ 104 | --depth-validation-loaders "kitti_zhou_test" 105 | ``` 106 | Additionally, an example script is provided in `eval_depth.sh`. 107 | 108 | ## Segmentation Evaluation 109 | For the evaluation of the segmentation results on Cityscapes, use `eval_segmentation.py`: 110 | 111 | ``` 112 | python3 eval_segmentation.py \ 113 | --sys-best-effort-determinism \ 114 | --model-name "eval_kitti_seg" \ 115 | --model-load sgdepth_eccv_test/zhou_full/checkpoints/epoch_20 \ 116 | --segmentation-validation-loaders "cityscapes_validation" \ 117 | --segmentation-validation-resize-width 1024 \ 118 | --segmentation-validation-resize-height 512 \ 119 | --eval-num-images 1 120 | ``` 121 | Additionally, an example script is provided in `eval_segmentation.sh`. 122 | 123 | ## Training 124 | To train the model, use `train.py`: 125 | ``` 126 | python3 train.py \ 127 | --model-name zhou_full \ 128 | --depth-training-loaders "kitti_zhou_train" \ 129 | --train-batches-per-epoch 7293 \ 130 | --masking-enable \ 131 | --masking-from-epoch 15 \ 132 | --masking-linear-increase 133 | ``` 134 | If you have any questions, feel free to contact us! 135 | 136 | ## License 137 | This code is licensed under the MIT License; feel free to use it within the boundaries of this license. 138 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ifnspaml/SGDepth/7d594b5153801d8240ec809ef6065eff1cd1f4fc/__init__.py -------------------------------------------------------------------------------- /colors/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as functional 3 | 4 | from .cityscapes import COLOR_SCHEME_CITYSCAPES 5 | from .plasma import COLOR_SCHEME_PLASMA 6 | from .tango import COLOR_SCHEME_TANGO 7 | 8 | def seg_prob_image(probs): 9 | """Takes a torch tensor of shape (N, C, H, W) containing a map 10 | of cityscapes class probabilities C as input and generates a 11 | color image of shape (N, C, H, W) from it. 12 | """ 13 | 14 | # Choose the number of categories according 15 | # to the dimension of the input Tensor 16 | colors = COLOR_SCHEME_CITYSCAPES[:probs.shape[1]] 17 | 18 | # Make the category channel the last dimension (N, W, H, C), 19 | # matrix multiply so that the color channel is the last 20 | # dimension and restore the shape array to (N, C, H, W). 21 | image = (probs.transpose(1, -1) @ colors).transpose(-1, 1) 22 | 23 | return image 24 | 25 | def domain_prob_image(probs): 26 | """Takes a torch tensor of shape (N, C, H, W) containing a map 27 | of domain probabilities C as input and generates a 28 | color image of shape (N, C, H, W) from it. 29 | """ 30 | 31 | # Choose the number of categories according 32 | # to the dimension of the input Tensor 33 | colors = COLOR_SCHEME_TANGO[:probs.shape[1]] 34 | 35 | # Make the category channel the last dimension (N, W, H, C), 36 | # matrix multiply so that the color channel is the last 37 | # dimension and restore the shape array to (N, C, H, W).
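# Shape check for the line below: probs is (N, C, H, W), so probs.transpose(1, -1) is (N, W, H, C);
# colors is (C, 3), so the matmul yields (N, W, H, 3), and transpose(-1, 1) restores (N, 3, H, W).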
38 | image = (probs.transpose(1, -1) @ colors).transpose(-1, 1) 39 | 40 | return image 41 | 42 | def seg_idx_image(idxs): 43 | """Takes a torch tensor of shape (N, H, W) containing a map 44 | of cityscapes train ids as input and generates a color image 45 | of shape (N, C, H, W) from it. 46 | """ 47 | 48 | # Take the dimensionality from (N, H, W) to (N, C, H, W) 49 | # and make the tensor invariant over the C dimension 50 | idxs = idxs.unsqueeze(1) 51 | idxs = idxs.expand(-1, 3, -1, -1) 52 | 53 | h, w = idxs.shape[2:] 54 | 55 | # Extend the dimensionality of the color scheme from 56 | # (IDX, C) to (IDX, C, H, W) and make it invariant over 57 | # the last two dimensions. 58 | color = COLOR_SCHEME_CITYSCAPES.unsqueeze(2).unsqueeze(3) 59 | color = color.expand(-1, -1, h, w) 60 | 61 | image = torch.gather(color, 0, idxs) 62 | 63 | return image 64 | 65 | def _depth_to_percentile_normalized_disp(depth): 66 | """This performs the same steps as normalize_depth_for_display 67 | from the SfMLearner repository, given the default options. 68 | This treats every image in the batch separately. 69 | """ 70 | 71 | disp = 1 / (depth + 1e-6) 72 | 73 | disp_sorted, _ = disp.flatten(1).sort(1) 74 | idx = disp_sorted.shape[1] * 95 // 100 75 | batch_percentiles = disp_sorted[:,idx].view(-1, 1, 1, 1) 76 | 77 | disp_norm = disp / (batch_percentiles + 1e-6) 78 | 79 | return disp_norm 80 | 81 | def depth_norm_image(depth): 82 | """Takes a torch tensor of shape (N, H, W) containing depth 83 | as input and outputs normalized depth images colored with the 84 | matplotlib plasma color scheme. 85 | """ 86 | 87 | # Perform the kind-of-industry-standard 88 | # normalization for image generation 89 | disp = _depth_to_percentile_normalized_disp(depth) 90 | 91 | # We generate two indexing maps into the colors 92 | # tensor and a map to interpolate between these 93 | # two indexed colors. 94 | # First scale the disp tensor from [0, 1) to [0, num_colors). 95 | # Then take the dimensionality from (N, H, W) to (N, C, H, W) 96 | # and make it invariant over the C dimension 97 | num_colors = COLOR_SCHEME_PLASMA.shape[0] 98 | idx = disp * num_colors 99 | idx = idx.expand(-1, 3, -1, -1) 100 | 101 | h, w = idx.shape[2:] 102 | 103 | # Extend the dimensionality of the color scheme from 104 | # (IDX, C) to (IDX, C, H, W) and make it invariant over 105 | # the last two dimensions. 106 | colors = COLOR_SCHEME_PLASMA.unsqueeze(2).unsqueeze(3) 107 | colors = colors.expand(-1, -1, h, w) 108 | 109 | # Values in idx are somewhere between two color indices.
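# Note: since disp was normalized by its 95th percentile, pixels above that percentile give idx >= num_colors;
# the clamps below saturate them at the last palette entry instead of indexing out of range.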
110 | # First generate an image based on the lower indices 111 | idx_low = idx.floor().long().clamp(0, num_colors - 1) 112 | img_low = torch.gather(colors, 0, idx_low) 113 | 114 | # Then generate an image based on the upper indices 115 | idx_high = (idx_low + 1).clamp(0, num_colors - 1) 116 | img_high = torch.gather(colors, 0, idx_high) 117 | 118 | # Then interpolate between these two 119 | sel_rel = (idx - idx_low.float()).clamp(0, 1) 120 | img = img_low + sel_rel * (img_high - img_low) 121 | 122 | return img 123 | 124 | def surface_normal_image(surface_normal): 125 | surface_normal = surface_normal.permute(0, 3, 1, 2) 126 | surface_normal = functional.pad(surface_normal, (0, 1, 0, 1), 'replicate') 127 | surface_normal = (surface_normal + 1) / 2 128 | 129 | return surface_normal 130 | -------------------------------------------------------------------------------- /colors/cityscapes.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from dataloader.definitions.labels_file import labels_cityscape_seg 3 | 4 | # Extract the Cityscapes color scheme 5 | TRID_TO_LABEL = labels_cityscape_seg.gettrainid2label() 6 | 7 | COLOR_SCHEME_CITYSCAPES = torch.tensor( 8 | tuple( 9 | TRID_TO_LABEL[tid].color if (tid in TRID_TO_LABEL) else (0, 0, 0) 10 | for tid in range(256) 11 | ) 12 | ).float() / 255 13 | -------------------------------------------------------------------------------- /colors/plasma.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | COLOR_SCHEME_PLASMA = torch.tensor(( 4 | (0.050383, 0.029803, 0.527975), 5 | (0.063536, 0.028426, 0.533124), 6 | (0.075353, 0.027206, 0.538007), 7 | (0.086222, 0.026125, 0.542658), 8 | (0.096379, 0.025165, 0.547103), 9 | (0.105980, 0.024309, 0.551368), 10 | (0.115124, 0.023556, 0.555468), 11 | (0.123903, 0.022878, 0.559423), 12 | (0.132381, 0.022258, 0.563250), 13 | (0.140603, 0.021687, 0.566959), 14 | (0.148607, 0.021154, 0.570562), 15 | (0.156421, 0.020651, 0.574065), 16 | (0.164070, 0.020171, 0.577478), 17 | (0.171574, 0.019706, 0.580806), 18 | (0.178950, 0.019252, 0.584054), 19 | (0.186213, 0.018803, 0.587228), 20 | (0.193374, 0.018354, 0.590330), 21 | (0.200445, 0.017902, 0.593364), 22 | (0.207435, 0.017442, 0.596333), 23 | (0.214350, 0.016973, 0.599239), 24 | (0.221197, 0.016497, 0.602083), 25 | (0.227983, 0.016007, 0.604867), 26 | (0.234715, 0.015502, 0.607592), 27 | (0.241396, 0.014979, 0.610259), 28 | (0.248032, 0.014439, 0.612868), 29 | (0.254627, 0.013882, 0.615419), 30 | (0.261183, 0.013308, 0.617911), 31 | (0.267703, 0.012716, 0.620346), 32 | (0.274191, 0.012109, 0.622722), 33 | (0.280648, 0.011488, 0.625038), 34 | (0.287076, 0.010855, 0.627295), 35 | (0.293478, 0.010213, 0.629490), 36 | (0.299855, 0.009561, 0.631624), 37 | (0.306210, 0.008902, 0.633694), 38 | (0.312543, 0.008239, 0.635700), 39 | (0.318856, 0.007576, 0.637640), 40 | (0.325150, 0.006915, 0.639512), 41 | (0.331426, 0.006261, 0.641316), 42 | (0.337683, 0.005618, 0.643049), 43 | (0.343925, 0.004991, 0.644710), 44 | (0.350150, 0.004382, 0.646298), 45 | (0.356359, 0.003798, 0.647810), 46 | (0.362553, 0.003243, 0.649245), 47 | (0.368733, 0.002724, 0.650601), 48 | (0.374897, 0.002245, 0.651876), 49 | (0.381047, 0.001814, 0.653068), 50 | (0.387183, 0.001434, 0.654177), 51 | (0.393304, 0.001114, 0.655199), 52 | (0.399411, 0.000859, 0.656133), 53 | (0.405503, 0.000678, 0.656977), 54 | (0.411580, 0.000577, 0.657730), 55 | (0.417642, 0.000564, 0.658390), 56 | (0.423689, 0.000646, 
0.658956), 57 | (0.429719, 0.000831, 0.659425), 58 | (0.435734, 0.001127, 0.659797), 59 | (0.441732, 0.001540, 0.660069), 60 | (0.447714, 0.002080, 0.660240), 61 | (0.453677, 0.002755, 0.660310), 62 | (0.459623, 0.003574, 0.660277), 63 | (0.465550, 0.004545, 0.660139), 64 | (0.471457, 0.005678, 0.659897), 65 | (0.477344, 0.006980, 0.659549), 66 | (0.483210, 0.008460, 0.659095), 67 | (0.489055, 0.010127, 0.658534), 68 | (0.494877, 0.011990, 0.657865), 69 | (0.500678, 0.014055, 0.657088), 70 | (0.506454, 0.016333, 0.656202), 71 | (0.512206, 0.018833, 0.655209), 72 | (0.517933, 0.021563, 0.654109), 73 | (0.523633, 0.024532, 0.652901), 74 | (0.529306, 0.027747, 0.651586), 75 | (0.534952, 0.031217, 0.650165), 76 | (0.540570, 0.034950, 0.648640), 77 | (0.546157, 0.038954, 0.647010), 78 | (0.551715, 0.043136, 0.645277), 79 | (0.557243, 0.047331, 0.643443), 80 | (0.562738, 0.051545, 0.641509), 81 | (0.568201, 0.055778, 0.639477), 82 | (0.573632, 0.060028, 0.637349), 83 | (0.579029, 0.064296, 0.635126), 84 | (0.584391, 0.068579, 0.632812), 85 | (0.589719, 0.072878, 0.630408), 86 | (0.595011, 0.077190, 0.627917), 87 | (0.600266, 0.081516, 0.625342), 88 | (0.605485, 0.085854, 0.622686), 89 | (0.610667, 0.090204, 0.619951), 90 | (0.615812, 0.094564, 0.617140), 91 | (0.620919, 0.098934, 0.614257), 92 | (0.625987, 0.103312, 0.611305), 93 | (0.631017, 0.107699, 0.608287), 94 | (0.636008, 0.112092, 0.605205), 95 | (0.640959, 0.116492, 0.602065), 96 | (0.645872, 0.120898, 0.598867), 97 | (0.650746, 0.125309, 0.595617), 98 | (0.655580, 0.129725, 0.592317), 99 | (0.660374, 0.134144, 0.588971), 100 | (0.665129, 0.138566, 0.585582), 101 | (0.669845, 0.142992, 0.582154), 102 | (0.674522, 0.147419, 0.578688), 103 | (0.679160, 0.151848, 0.575189), 104 | (0.683758, 0.156278, 0.571660), 105 | (0.688318, 0.160709, 0.568103), 106 | (0.692840, 0.165141, 0.564522), 107 | (0.697324, 0.169573, 0.560919), 108 | (0.701769, 0.174005, 0.557296), 109 | (0.706178, 0.178437, 0.553657), 110 | (0.710549, 0.182868, 0.550004), 111 | (0.714883, 0.187299, 0.546338), 112 | (0.719181, 0.191729, 0.542663), 113 | (0.723444, 0.196158, 0.538981), 114 | (0.727670, 0.200586, 0.535293), 115 | (0.731862, 0.205013, 0.531601), 116 | (0.736019, 0.209439, 0.527908), 117 | (0.740143, 0.213864, 0.524216), 118 | (0.744232, 0.218288, 0.520524), 119 | (0.748289, 0.222711, 0.516834), 120 | (0.752312, 0.227133, 0.513149), 121 | (0.756304, 0.231555, 0.509468), 122 | (0.760264, 0.235976, 0.505794), 123 | (0.764193, 0.240396, 0.502126), 124 | (0.768090, 0.244817, 0.498465), 125 | (0.771958, 0.249237, 0.494813), 126 | (0.775796, 0.253658, 0.491171), 127 | (0.779604, 0.258078, 0.487539), 128 | (0.783383, 0.262500, 0.483918), 129 | (0.787133, 0.266922, 0.480307), 130 | (0.790855, 0.271345, 0.476706), 131 | (0.794549, 0.275770, 0.473117), 132 | (0.798216, 0.280197, 0.469538), 133 | (0.801855, 0.284626, 0.465971), 134 | (0.805467, 0.289057, 0.462415), 135 | (0.809052, 0.293491, 0.458870), 136 | (0.812612, 0.297928, 0.455338), 137 | (0.816144, 0.302368, 0.451816), 138 | (0.819651, 0.306812, 0.448306), 139 | (0.823132, 0.311261, 0.444806), 140 | (0.826588, 0.315714, 0.441316), 141 | (0.830018, 0.320172, 0.437836), 142 | (0.833422, 0.324635, 0.434366), 143 | (0.836801, 0.329105, 0.430905), 144 | (0.840155, 0.333580, 0.427455), 145 | (0.843484, 0.338062, 0.424013), 146 | (0.846788, 0.342551, 0.420579), 147 | (0.850066, 0.347048, 0.417153), 148 | (0.853319, 0.351553, 0.413734), 149 | (0.856547, 0.356066, 0.410322), 150 | (0.859750, 0.360588, 0.406917), 151 | 
(0.862927, 0.365119, 0.403519), 152 | (0.866078, 0.369660, 0.400126), 153 | (0.869203, 0.374212, 0.396738), 154 | (0.872303, 0.378774, 0.393355), 155 | (0.875376, 0.383347, 0.389976), 156 | (0.878423, 0.387932, 0.386600), 157 | (0.881443, 0.392529, 0.383229), 158 | (0.884436, 0.397139, 0.379860), 159 | (0.887402, 0.401762, 0.376494), 160 | (0.890340, 0.406398, 0.373130), 161 | (0.893250, 0.411048, 0.369768), 162 | (0.896131, 0.415712, 0.366407), 163 | (0.898984, 0.420392, 0.363047), 164 | (0.901807, 0.425087, 0.359688), 165 | (0.904601, 0.429797, 0.356329), 166 | (0.907365, 0.434524, 0.352970), 167 | (0.910098, 0.439268, 0.349610), 168 | (0.912800, 0.444029, 0.346251), 169 | (0.915471, 0.448807, 0.342890), 170 | (0.918109, 0.453603, 0.339529), 171 | (0.920714, 0.458417, 0.336166), 172 | (0.923287, 0.463251, 0.332801), 173 | (0.925825, 0.468103, 0.329435), 174 | (0.928329, 0.472975, 0.326067), 175 | (0.930798, 0.477867, 0.322697), 176 | (0.933232, 0.482780, 0.319325), 177 | (0.935630, 0.487712, 0.315952), 178 | (0.937990, 0.492667, 0.312575), 179 | (0.940313, 0.497642, 0.309197), 180 | (0.942598, 0.502639, 0.305816), 181 | (0.944844, 0.507658, 0.302433), 182 | (0.947051, 0.512699, 0.299049), 183 | (0.949217, 0.517763, 0.295662), 184 | (0.951344, 0.522850, 0.292275), 185 | (0.953428, 0.527960, 0.288883), 186 | (0.955470, 0.533093, 0.285490), 187 | (0.957469, 0.538250, 0.282096), 188 | (0.959424, 0.543431, 0.278701), 189 | (0.961336, 0.548636, 0.275305), 190 | (0.963203, 0.553865, 0.271909), 191 | (0.965024, 0.559118, 0.268513), 192 | (0.966798, 0.564396, 0.265118), 193 | (0.968526, 0.569700, 0.261721), 194 | (0.970205, 0.575028, 0.258325), 195 | (0.971835, 0.580382, 0.254931), 196 | (0.973416, 0.585761, 0.251540), 197 | (0.974947, 0.591165, 0.248151), 198 | (0.976428, 0.596595, 0.244767), 199 | (0.977856, 0.602051, 0.241387), 200 | (0.979233, 0.607532, 0.238013), 201 | (0.980556, 0.613039, 0.234646), 202 | (0.981826, 0.618572, 0.231287), 203 | (0.983041, 0.624131, 0.227937), 204 | (0.984199, 0.629718, 0.224595), 205 | (0.985301, 0.635330, 0.221265), 206 | (0.986345, 0.640969, 0.217948), 207 | (0.987332, 0.646633, 0.214648), 208 | (0.988260, 0.652325, 0.211364), 209 | (0.989128, 0.658043, 0.208100), 210 | (0.989935, 0.663787, 0.204859), 211 | (0.990681, 0.669558, 0.201642), 212 | (0.991365, 0.675355, 0.198453), 213 | (0.991985, 0.681179, 0.195295), 214 | (0.992541, 0.687030, 0.192170), 215 | (0.993032, 0.692907, 0.189084), 216 | (0.993456, 0.698810, 0.186041), 217 | (0.993814, 0.704741, 0.183043), 218 | (0.994103, 0.710698, 0.180097), 219 | (0.994324, 0.716681, 0.177208), 220 | (0.994474, 0.722691, 0.174381), 221 | (0.994553, 0.728728, 0.171622), 222 | (0.994561, 0.734791, 0.168938), 223 | (0.994495, 0.740880, 0.166335), 224 | (0.994355, 0.746995, 0.163821), 225 | (0.994141, 0.753137, 0.161404), 226 | (0.993851, 0.759304, 0.159092), 227 | (0.993482, 0.765499, 0.156891), 228 | (0.993033, 0.771720, 0.154808), 229 | (0.992505, 0.777967, 0.152855), 230 | (0.991897, 0.784239, 0.151042), 231 | (0.991209, 0.790537, 0.149377), 232 | (0.990439, 0.796859, 0.147870), 233 | (0.989587, 0.803205, 0.146529), 234 | (0.988648, 0.809579, 0.145357), 235 | (0.987621, 0.815978, 0.144363), 236 | (0.986509, 0.822401, 0.143557), 237 | (0.985314, 0.828846, 0.142945), 238 | (0.984031, 0.835315, 0.142528), 239 | (0.982653, 0.841812, 0.142303), 240 | (0.981190, 0.848329, 0.142279), 241 | (0.979644, 0.854866, 0.142453), 242 | (0.977995, 0.861432, 0.142808), 243 | (0.976265, 0.868016, 0.143351), 244 | (0.974443, 0.874622, 
0.144061), 245 | (0.972530, 0.881250, 0.144923), 246 | (0.970533, 0.887896, 0.145919), 247 | (0.968443, 0.894564, 0.147014), 248 | (0.966271, 0.901249, 0.148180), 249 | (0.964021, 0.907950, 0.149370), 250 | (0.961681, 0.914672, 0.150520), 251 | (0.959276, 0.921407, 0.151566), 252 | (0.956808, 0.928152, 0.152409), 253 | (0.954287, 0.934908, 0.152921), 254 | (0.951726, 0.941671, 0.152925), 255 | (0.949151, 0.948435, 0.152178), 256 | (0.946602, 0.955190, 0.150328), 257 | (0.944152, 0.961916, 0.146861), 258 | (0.941896, 0.968590, 0.140956), 259 | (0.940015, 0.975158, 0.131326) 260 | )) 261 | -------------------------------------------------------------------------------- /colors/tango.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | COLOR_SCHEME_TANGO = torch.tensor(( 4 | (0.937254, 0.160784, 0.160784), 5 | (0.541176, 0.886274, 0.203921), 6 | (0.447058, 0.623529, 0.811764), 7 | (0.988235, 0.913725, 0.309803), 8 | (0.678431, 0.498039, 0.658823), 9 | (0.988235, 0.686274, 0.243137), 10 | (0.913725, 0.725490, 0.431372) 11 | )) 12 | -------------------------------------------------------------------------------- /dataloader/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ifnspaml/SGDepth/7d594b5153801d8240ec809ef6065eff1cd1f4fc/dataloader/__init__.py -------------------------------------------------------------------------------- /dataloader/data_preprocessing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ifnspaml/SGDepth/7d594b5153801d8240ec809ef6065eff1cd1f4fc/dataloader/data_preprocessing/__init__.py -------------------------------------------------------------------------------- /dataloader/data_preprocessing/download_kitti.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import wget 4 | import zipfile 5 | import pandas as pd 6 | import shutil 7 | import numpy as np 8 | import glob 9 | import cv2 10 | 11 | import dataloader.file_io.get_path as gp 12 | import dataloader.file_io.dir_lister as dl 13 | from kitti_utils import pcl_to_depth_map 14 | 15 | 16 | def download_kitti_all(kitti_folder='kitti_download'): 17 | """ This pathon-script downloads all KITTI folders and aranges them in a 18 | coherent data structure which can respectively be used by the other data 19 | scripts. It is recommended to keep the standard name KITTI. Note that the 20 | path is determined automatically inside of file_io/get_path.py 21 | 22 | parameters: 23 | - kitti_folder: Name of the folder in which the dataset should be downloaded 24 | This is no path but just a name. 
the path ios determined by 25 | get_path.py 26 | 27 | """ 28 | 29 | # Download the standard KITTI Raw data 30 | 31 | path_getter = gp.GetPath() 32 | dataset_folder_path = path_getter.get_data_path() 33 | assert os.path.isdir(dataset_folder_path), 'Path to dataset folder does not exist' 34 | 35 | kitti_path = os.path.join(dataset_folder_path, kitti_folder) 36 | kitti_raw_data = pd.read_csv('kitti_archives_to_download.txt', 37 | header=None, delimiter=' ')[0].values 38 | kitti_path_raw = os.path.join(kitti_path, 'Raw_data') 39 | if not os.path.isdir(kitti_path_raw): 40 | os.makedirs(kitti_path_raw) 41 | for url in kitti_raw_data: 42 | folder = os.path.split(url)[1] 43 | folder = os.path.join(kitti_path_raw, folder) 44 | folder = folder[:-4] 45 | wget.download(url, out=kitti_path_raw) 46 | unzipper = zipfile.ZipFile(folder + '.zip', 'r') 47 | unzipper.extractall(kitti_path_raw) 48 | unzipper.close() 49 | os.remove(folder + '.zip') 50 | 51 | kitti_dirs_days = os.listdir(kitti_path_raw) 52 | 53 | # Get ground truth depths 54 | 55 | kitti_path_depth_annotated = os.path.join(kitti_path, 'Depth_improved') 56 | if not os.path.isdir(kitti_path_depth_annotated): 57 | os.makedirs(kitti_path_depth_annotated) 58 | url_depth_annotated = 'https://s3.eu-central-1.amazonaws.com/avg-kitti/data_depth_annotated.zip' 59 | wget.download(url_depth_annotated, out=kitti_path_depth_annotated) 60 | depth_zipped = os.path.join(kitti_path_depth_annotated, os.path.split(url_depth_annotated)[1]) 61 | unzipper = zipfile.ZipFile(depth_zipped, 'r') 62 | unzipper.extractall(kitti_path_depth_annotated) 63 | unzipper.close() 64 | os.remove(depth_zipped) 65 | 66 | trainval_folder = os.listdir(kitti_path_depth_annotated) 67 | kitti_drives_list = [] 68 | for sub_folder in trainval_folder: 69 | sub_folder = os.path.join(kitti_path_depth_annotated, sub_folder) 70 | kitti_drives_list.extend([os.path.join(sub_folder, i) for i in os.listdir(sub_folder)]) 71 | 72 | for sub_folder in kitti_dirs_days: 73 | sub_folder = os.path.join(kitti_path_depth_annotated, sub_folder) 74 | if not os.path.isdir(sub_folder): 75 | os.makedirs(sub_folder) 76 | for drive in kitti_drives_list: 77 | if os.path.split(sub_folder)[1] in drive: 78 | shutil.move(drive, sub_folder) 79 | 80 | for sub_folder in trainval_folder: 81 | sub_folder = os.path.join(kitti_path_depth_annotated, sub_folder) 82 | shutil.rmtree(sub_folder) 83 | 84 | # Get sparse depths 85 | 86 | kitti_path_depth_sparse = os.path.join(kitti_path, 'Depth_projected') 87 | if not os.path.isdir(kitti_path_depth_sparse): 88 | os.makedirs(kitti_path_depth_sparse) 89 | url_depth_sparse = 'https://s3.eu-central-1.amazonaws.com/avg-kitti/data_depth_velodyne.zip' 90 | wget.download(url_depth_sparse, out=kitti_path_depth_sparse) 91 | depth_zipped = os.path.join(kitti_path_depth_sparse, os.path.split(url_depth_sparse)[1]) 92 | unzipper = zipfile.ZipFile(depth_zipped, 'r') 93 | unzipper.extractall(kitti_path_depth_sparse) 94 | unzipper.close() 95 | os.remove(depth_zipped) 96 | 97 | trainval_folder = os.listdir(kitti_path_depth_sparse) 98 | kitti_drives_list = [] 99 | for sub_folder in trainval_folder: 100 | sub_folder = os.path.join(kitti_path_depth_sparse, sub_folder) 101 | kitti_drives_list.extend([os.path.join(sub_folder, i) for i in os.listdir(sub_folder)]) 102 | 103 | for sub_folder in kitti_dirs_days: 104 | sub_folder = os.path.join(kitti_path_depth_sparse, sub_folder) 105 | if not os.path.isdir(sub_folder): 106 | os.makedirs(sub_folder) 107 | for drive in kitti_drives_list: 108 | if 
os.path.split(sub_folder)[1] in drive: 109 | shutil.move(drive, sub_folder) 110 | 111 | for sub_folder in trainval_folder: 112 | sub_folder = os.path.join(kitti_path_depth_sparse, sub_folder) 113 | shutil.rmtree(sub_folder) 114 | 115 | # download test_files and integrate them into the folder structure 116 | 117 | url_depth_testset = 'https://s3.eu-central-1.amazonaws.com/avg-kitti/data_depth_selection.zip' 118 | wget.download(url_depth_testset, out=kitti_path) 119 | depth_zipped = os.path.join(kitti_path, os.path.split(url_depth_testset)[1]) 120 | unzipper = zipfile.ZipFile(depth_zipped, 'r') 121 | unzipper.extractall(kitti_path) 122 | unzipper.close() 123 | os.remove(depth_zipped) 124 | 125 | init_depth_completion_folder = os.path.join(kitti_path, 'depth_selection', 126 | 'test_depth_completion_anonymous', 'image') 127 | target_depth_completion_folder = os.path.join(kitti_path_raw, 'test_depth_completion', 'image_02') 128 | if not os.path.isdir(target_depth_completion_folder): 129 | os.makedirs(target_depth_completion_folder) 130 | shutil.move(init_depth_completion_folder, target_depth_completion_folder) 131 | os.rename(os.path.join(target_depth_completion_folder, os.path.split(init_depth_completion_folder)[1]), 132 | os.path.join(target_depth_completion_folder, 'data')) 133 | 134 | init_depth_completion_folder = os.path.join(kitti_path, 'depth_selection', 135 | 'test_depth_completion_anonymous', 'intrinsics') 136 | target_depth_completion_folder = os.path.join(kitti_path_raw, 'test_depth_completion') 137 | shutil.move(init_depth_completion_folder, target_depth_completion_folder) 138 | 139 | init_depth_completion_folder = os.path.join(kitti_path, 'depth_selection', 140 | 'test_depth_completion_anonymous', 'velodyne_raw') 141 | target_depth_completion_folder = os.path.join(kitti_path_depth_sparse, 'test_depth_completion', 'image_02') 142 | if not os.path.isdir(target_depth_completion_folder): 143 | os.makedirs(target_depth_completion_folder) 144 | shutil.move(init_depth_completion_folder, target_depth_completion_folder) 145 | os.rename(os.path.join(target_depth_completion_folder, os.path.split(init_depth_completion_folder)[1]), 146 | os.path.join(target_depth_completion_folder, 'data')) 147 | 148 | init_depth_prediction_folder = os.path.join(kitti_path, 'depth_selection', 149 | 'test_depth_prediction_anonymous', 'image') 150 | target_depth_prediction_folder = os.path.join(kitti_path_raw, 'test_depth_prediction', 'image_02') 151 | if not os.path.isdir(target_depth_prediction_folder): 152 | os.makedirs(target_depth_prediction_folder) 153 | shutil.move(init_depth_prediction_folder, target_depth_prediction_folder) 154 | os.rename(os.path.join(target_depth_prediction_folder, os.path.split(init_depth_prediction_folder)[1]), 155 | os.path.join(target_depth_prediction_folder, 'data')) 156 | 157 | init_depth_prediction_folder = os.path.join(kitti_path, 'depth_selection', 158 | 'test_depth_prediction_anonymous', 'intrinsics') 159 | target_depth_prediction_folder = os.path.join(kitti_path_raw, 'test_depth_prediction') 160 | shutil.move(init_depth_prediction_folder, target_depth_prediction_folder) 161 | 162 | shutil.rmtree(os.path.join(kitti_path, 'depth_selection')) 163 | 164 | 165 | def adjust_improvedgt_folders(kitti_folder = 'kitti_download'): 166 | """ This function adjust the format of the improved ground truth folder structure 167 | to the structure of the KITTI raw data and afterward removes the old directories. 
168 | It is taken care that only the directories from the Download are worked on so that 169 | the procedure does not work on directories which it is not supposed to""" 170 | 171 | path_getter = gp.GetPath() 172 | dataset_folder_path = path_getter.get_data_path() 173 | gt_path = os.path.join(dataset_folder_path, kitti_folder) 174 | gt_path = os.path.join(gt_path, 'Depth_improved') 175 | assert os.path.isdir(gt_path), 'Path to data does not exist' 176 | folders = dl.DirLister.get_directories(gt_path) 177 | folders = dl.DirLister.include_dirs_by_name(folders, 'proj_depth') 178 | for f in folders: 179 | ground_path, camera = os.path.split(f) 180 | ground_path = os.path.split(ground_path)[0] 181 | ground_path = os.path.split(ground_path)[0] 182 | target_path = os.path.join(ground_path, camera, 'data') 183 | if not os.path.isdir(target_path): 184 | os.makedirs(target_path) 185 | else: 186 | continue 187 | for filepath in glob.glob(os.path.join(f, '*')): 188 | # Move each file to destination Directory 189 | shutil.move(filepath, target_path) 190 | print(target_path) 191 | 192 | for f in folders: 193 | remove_path = os.path.split(f)[0] 194 | remove_path = os.path.split(remove_path)[0] 195 | print(remove_path) 196 | shutil.rmtree(remove_path, ignore_errors=True) 197 | 198 | 199 | def adjust_projectedvelodyne_folders(kitti_folder='kitti_download'): 200 | """ This function adjust the format of the sparse ground truth folder structure 201 | to the structure of the KITTI raw data and afterward removes the old directories. 202 | It is taken care that only the directories from the Download are worked on so that 203 | the procedure does not work on directories which it is not supposed to""" 204 | 205 | path_getter = gp.GetPath() 206 | dataset_folder_path = path_getter.get_data_path() 207 | gt_path = os.path.join(dataset_folder_path, kitti_folder) 208 | gt_path = os.path.join(gt_path, 'Depth_projected') 209 | assert os.path.isdir(gt_path), 'Path to data does not exist' 210 | folders = dl.DirLister.get_directories(gt_path) 211 | folders = dl.DirLister.include_dirs_by_name(folders, 'proj_depth') 212 | for f in folders: 213 | ground_path, camera = os.path.split(f) 214 | ground_path = os.path.split(ground_path)[0] 215 | ground_path = os.path.split(ground_path)[0] 216 | target_path = os.path.join(ground_path, camera, 'data') 217 | if not os.path.isdir(target_path): 218 | os.makedirs(target_path) 219 | else: 220 | continue 221 | for filepath in glob.glob(os.path.join(f, '*')): 222 | # Move each file to destination Directory 223 | shutil.move(filepath, target_path) 224 | print(target_path) 225 | 226 | for f in folders: 227 | remove_path = os.path.split(f)[0] 228 | remove_path = os.path.split(remove_path)[0] 229 | print(remove_path) 230 | shutil.rmtree(remove_path, ignore_errors=True) 231 | 232 | 233 | def generate_depth_from_velo(kitti_folder='kitti_download'): 234 | """ This function generates the depth maps that correspond to the 235 | single point clouds of the raw LiDAR scans""" 236 | 237 | path_getter = gp.GetPath() 238 | dataset_folder_path = path_getter.get_data_path() 239 | gt_path = os.path.join(dataset_folder_path, kitti_folder) 240 | depth_path = os.path.join(gt_path, 'Depth') 241 | gt_path = os.path.join(gt_path, 'Raw_data') 242 | assert os.path.isdir(gt_path), 'Path to data does not exist' 243 | folders = dl.DirLister.get_directories(gt_path) 244 | folders = dl.DirLister.include_dirs_by_name(folders, 'velodyne_points') 245 | for f in folders: 246 | base_dir = os.path.split(f)[0] 247 | base_dir = 
os.path.split(base_dir)[0] 248 | calib_dir = os.path.split(base_dir)[0] 249 | image_dir_2 = os.path.join(base_dir, 'image_02', 'data') 250 | image_dir_3 = os.path.join(base_dir, 'image_03', 'data') 251 | day, drive = os.path.split(base_dir) 252 | day = os.path.split(day)[1] 253 | depth_dir_2 = os.path.join(depth_path, day, drive, 'image_02', 'data') 254 | depth_dir_3 = os.path.join(depth_path, day, drive, 'image_03', 'data') 255 | if not os.path.isdir(depth_dir_2): 256 | os.makedirs(depth_dir_2) 257 | if not os.path.isdir(depth_dir_3): 258 | os.makedirs(depth_dir_3) 259 | 260 | for file in glob.glob(os.path.join(f, '*')): 261 | filename = os.path.split(file)[1] 262 | filename_img = filename[:-3] + 'png' 263 | im_size_2 = cv2.imread(os.path.join(image_dir_2, filename_img)).shape[:2] 264 | im_size_3 = cv2.imread(os.path.join(image_dir_3, filename_img)).shape[:2] 265 | depth_2 = pcl_to_depth_map(calib_dir, file, im_size_2, 2) 266 | depth_3 = pcl_to_depth_map(calib_dir, file, im_size_3, 3) 267 | depth_2 = (depth_2 * 256).astype(np.uint16) 268 | depth_3 = (depth_3 * 256).astype(np.uint16) 269 | 270 | cv2.imwrite(os.path.join(depth_dir_2, filename_img), depth_2) 271 | cv2.imwrite(os.path.join(depth_dir_3, filename_img), depth_3) 272 | print(f) 273 | 274 | 275 | if __name__ == '__main__': 276 | kitti_folder = 'kitti_download' 277 | download_kitti_all(kitti_folder) 278 | adjust_improvedgt_folders(kitti_folder) 279 | adjust_projectedvelodyne_folders(kitti_folder) 280 | generate_depth_from_velo(kitti_folder) -------------------------------------------------------------------------------- /dataloader/data_preprocessing/kitti_2015_generate_depth.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import os 4 | import sys 5 | 6 | sys.path.append('../../../') 7 | import dataloader.file_io.get_path as gp 8 | 9 | disp_names = ['disp_noc_0', 'disp_noc_1', 'disp_occ_0', 'disp_occ_1'] 10 | depth_names = ['depth_noc_0', 'depth_noc_1', 'depth_occ_0', 'depth_occ_1'] 11 | line_numbers = [20, 28, 20, 28] 12 | for disp_name, depth_name, line_number in zip(disp_names, depth_names, line_numbers): 13 | path_getter = gp.GetPath() 14 | data_path = path_getter.get_data_path() 15 | data_path = os.path.join(data_path, 'kitti_2015', 'training') 16 | disp_path = os.path.join(data_path, disp_name) 17 | depth_path = os.path.join(data_path, depth_name) 18 | if not os.path.isdir(depth_path): 19 | os.makedirs(depth_path) 20 | calib_path = os.path.join(data_path, 'calib_cam_to_cam') 21 | file_list_im = os.listdir(disp_path) 22 | file_list_cam = os.listdir(calib_path) 23 | for image_file, cam_file in zip(file_list_im, file_list_cam): 24 | im_file = os.path.join(disp_path, image_file) 25 | cam_file = os.path.join(calib_path, cam_file) 26 | disp = cv2.imread(im_file, -1).astype(np.float)/256. 
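# The KITTI 2015 disparity PNGs store disparity * 256 as uint16, hence the division by 256 above.
# The next lines read the camera-matrix line selected via line_number from the calibration file
# (a flattened 3x3 matrix whose entries 0 and 4 hold the focal lengths f_x and f_y), average the
# two focal lengths, and convert disparity to depth via depth = f * B / disparity, where 0.54 is
# roughly the KITTI stereo baseline B in meters; the tiny epsilon only avoids division by zero
# for invalid (zero-disparity) pixels, which are reset to zero depth afterwards.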
27 | cam_matrix = open(cam_file).readlines()[:line_number][-1][6:].split() 28 | foc_length = (float(cam_matrix[0]) + float(cam_matrix[4]))/2.0 29 | depth = 0.54*foc_length/(disp + 0.00000000001) 30 | depth[disp == 0] = 0 31 | depth = (depth*256).astype(np.uint16) 32 | cv2.imwrite(os.path.join(depth_path, image_file), depth) 33 | -------------------------------------------------------------------------------- /dataloader/data_preprocessing/kitti_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | from collections import Counter 5 | 6 | 7 | # adapted from https://github.com/nianticlabs/monodepth2 8 | 9 | def pcl_to_depth_map(calib_dir, velo_file_name, im_shape, cam=2, vel_depth=False): 10 | #vel_depth should be False as Eigen computed the results relative to velo 11 | # load calibration files 12 | cam2cam = read_calib_file(os.path.join(calib_dir, 'calib_cam_to_cam.txt')) 13 | velo2cam = read_calib_file(os.path.join(calib_dir, 'calib_velo_to_cam.txt')) 14 | velo2cam = np.hstack((velo2cam['R'].reshape(3, 3), velo2cam['T'][..., np.newaxis])) 15 | velo2cam = np.vstack((velo2cam, np.array([0, 0, 0, 1.0]))) 16 | 17 | # compute projection matrix velodyne->image plane 18 | R_cam2rect = np.eye(4) 19 | R_cam2rect[:3, :3] = cam2cam['R_rect_00'].reshape(3, 3) 20 | P_rect = cam2cam['P_rect_0' + str(cam)].reshape(3, 4) 21 | P_velo2im = np.dot(np.dot(P_rect, R_cam2rect), velo2cam) 22 | 23 | # load velodyne points and remove all behind image plane (approximation) 24 | # each row of the velodyne data is forward, left, up, reflectance 25 | velo = load_velodyne_points(velo_file_name) 26 | velo = velo[velo[:, 0] >= 0, :] 27 | 28 | # project the points to the camera 29 | velo_pts_im = np.dot(P_velo2im, velo.T).T 30 | velo_pts_im[:, :2] = velo_pts_im[:, :2] / velo_pts_im[:, 2][..., np.newaxis] 31 | 32 | if vel_depth: 33 | velo_pts_im[:, 2] = velo[:, 0] 34 | 35 | # check if in bounds 36 | # use minus 1 to get the exact same value as KITTI matlab code 37 | velo_pts_im[:, 0] = np.round(velo_pts_im[:, 0]) - 1 38 | velo_pts_im[:, 1] = np.round(velo_pts_im[:, 1]) - 1 39 | val_inds = (velo_pts_im[:, 0] >= 0) & (velo_pts_im[:, 1] >= 0) 40 | val_inds = val_inds & (velo_pts_im[:, 0] < im_shape[1]) & (velo_pts_im[:, 1] < im_shape[0]) 41 | velo_pts_im = velo_pts_im[val_inds, :] 42 | 43 | # project to image 44 | depth = np.zeros((im_shape)) 45 | depth[velo_pts_im[:, 1].astype(np.int), velo_pts_im[:, 0].astype(np.int)] = velo_pts_im[:, 2] 46 | 47 | #find the duplicate points and choose the closest depth 48 | inds = sub2ind(depth.shape, velo_pts_im[:, 1], velo_pts_im[:, 0]) 49 | dupe_inds = [item for item, count in Counter(inds).items() if count > 1] 50 | for dd in dupe_inds: 51 | pts = np.where(inds == dd)[0] 52 | x_loc = int(velo_pts_im[pts[0], 0]) 53 | y_loc = int(velo_pts_im[pts[0], 1]) 54 | depth[y_loc, x_loc] = velo_pts_im[pts, 2].min() 55 | depth[depth < 0] = 0 56 | 57 | return depth 58 | 59 | 60 | # adapted from https://github.com/hunse/kitti 61 | 62 | def read_calib_file(path): 63 | float_chars = set("0123456789.e+- ") 64 | data = {} 65 | with open(path, 'r') as f: 66 | for line in f.readlines(): 67 | key, value = line.split(':', 1) 68 | value = value.strip() 69 | data[key] = value 70 | if float_chars.issuperset(value): 71 | # try to cast to float array 72 | try: 73 | data[key] = np.array(list(map(float, value.split(' ')))) 74 | except ValueError: 75 | # casting error: data[key] already eq. 
value, so pass 76 | pass 77 | 78 | return data 79 | 80 | 81 | # adapted from https://github.com/hunse/kitti 82 | 83 | def load_velodyne_points(file_name): 84 | points = np.fromfile(file_name, dtype=np.float32).reshape(-1, 4) 85 | points[:, 3] = 1.0 # homogeneous 86 | return points 87 | 88 | 89 | # adapted from https://github.com/nianticlabs/monodepth2 90 | 91 | def sub2ind(matrixSize, rowSub, colSub): 92 | m, n = matrixSize 93 | return rowSub * (n - 1) + colSub - 1 94 | 95 | 96 | def read_depth(filename, factor=256.): 97 | depth_png = np.array(cv2.imread(filename, -1)).astype(np.float) 98 | # make sure we have a proper 16bit depth map here.. not 8bit! 99 | assert(np.max(depth_png) > 255) 100 | 101 | depth = depth_png.astype(np.float) / factor 102 | depth[depth_png == 0] = 0 103 | return depth 104 | -------------------------------------------------------------------------------- /dataloader/definitions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ifnspaml/SGDepth/7d594b5153801d8240ec809ef6065eff1cd1f4fc/dataloader/definitions/__init__.py -------------------------------------------------------------------------------- /dataloader/definitions/labels_file.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2020 Marvin Klingner 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | 24 | 25 | from collections import namedtuple 26 | import numpy as np 27 | 28 | 29 | Label = namedtuple( 'Label' , [ 30 | 31 | 'name' , # The identifier of this label. 32 | # We use them to uniquely name a class 33 | 34 | 'id' , # An integer ID that is associated with this label. 35 | # The IDs are used to represent the label in ground truth images 36 | # An ID of -1 means that this label does not have an ID and thus 37 | # is ignored when creating ground truth images (e.g. license plate). 38 | # Do not modify these IDs, since exactly these IDs are expected by the 39 | # evaluation server. 40 | 41 | 'trainId' , # Feel free to modify these IDs as suitable for your method. Then create 42 | # ground truth images with train IDs, using the tools provided in the 43 | # 'preparation' folder. However, make sure to validate or submit results 44 | # to our evaluation server using the regular IDs above! 45 | # For trainIds, multiple labels might have the same ID. 
Then, these labels 46 | # are mapped to the same class in the ground truth images. For the inverse 47 | # mapping, we use the label that is defined first in the list below. 48 | # For example, mapping all void-type classes to the same ID in training, 49 | # might make sense for some approaches. 50 | # Max value is 255! 51 | 52 | 'category' , # The name of the category that this label belongs to 53 | 54 | 'categoryId' , # The ID of this category. Used to create ground truth images 55 | # on category level. 56 | 57 | 'hasInstances', # Whether this label distinguishes between single instances or not 58 | 59 | 'ignoreInEval', # Whether pixels having this class as ground truth label are ignored 60 | # during evaluations or not 61 | 62 | 'color' , # The color of this label 63 | ] ) 64 | 65 | 66 | class ClassDefinitions(object): 67 | """This class contains the classdefintions for the segmentation masks and the 68 | procedures to work with them""" 69 | 70 | def __init__(self, classlabels): 71 | self.labels = classlabels 72 | for i, label in zip(range(len(self.labels)), self.labels): 73 | if isinstance(label.color, int): 74 | self.labels[i] = label._replace(color=tuple([int(label.color/(256.0**2)) % 256, 75 | int(label.color/256.0) % 256, 76 | int(label.color) % 256])) 77 | 78 | def getlabels(self): 79 | return self.labels 80 | 81 | def getname2label(self): 82 | name2label = {label.name: label for label in self.labels} 83 | return name2label 84 | 85 | def getid2label(self): 86 | id2label = {label.id: label for label in self.labels} 87 | return id2label 88 | 89 | def gettrainid2label(self): 90 | trainid2label = {label.trainId: label for label in reversed(self.labels)} 91 | return trainid2label 92 | 93 | def getcategory2label(self): 94 | category2labels = {} 95 | for label in self.labels: 96 | category = label.category 97 | if category in category2labels: 98 | category2labels[category].append(label) 99 | else: 100 | category2labels[category] = [label] 101 | 102 | def assureSingleInstanceName(self,name): 103 | # if the name is known, it is not a group 104 | name2label = self.getname2label() 105 | if name in name2label: 106 | return name 107 | # test if the name actually denotes a group 108 | if not name.endswith("group"): 109 | return None 110 | # remove group 111 | name = name[:-len("group")] 112 | # test if the new name exists 113 | if not name in name2label: 114 | return None 115 | # test if the new name denotes a label that actually has instances 116 | if not name2label[name].hasInstances: 117 | return None 118 | # all good then 119 | return name 120 | 121 | 122 | labels_cityscape_seg = ClassDefinitions([ 123 | # name id trainId category catId hasInstances ignoreInEval color 124 | Label( 'unlabeled' , 0 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 125 | Label( 'ego vehicle' , 1 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 126 | Label( 'rectification border' , 2 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 127 | Label( 'out of roi' , 3 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 128 | Label( 'static' , 4 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 129 | Label( 'dynamic' , 5 , 255 , 'void' , 0 , False , True , (111, 74, 0) ), 130 | Label( 'ground' , 6 , 255 , 'void' , 0 , False , True , ( 81, 0, 81) ), 131 | Label( 'road' , 7 , 0 , 'flat' , 1 , False , False , (128, 64,128) ), 132 | Label( 'sidewalk' , 8 , 1 , 'flat' , 1 , False , False , (244, 35,232) ), 133 | Label( 'parking' , 9 , 255 , 'flat' , 1 , False , True , (250,170,160) ), 134 | Label( 'rail track' , 10 , 255 , 
'flat' , 1 , False , True , (230,150,140) ), 135 | Label( 'building' , 11 , 2 , 'construction' , 2 , False , False , ( 70, 70, 70) ), 136 | Label( 'wall' , 12 , 3 , 'construction' , 2 , False , False , (102,102,156) ), 137 | Label( 'fence' , 13 , 4 , 'construction' , 2 , False , False , (190,153,153) ), 138 | Label( 'guard rail' , 14 , 255 , 'construction' , 2 , False , True , (180,165,180) ), 139 | Label( 'bridge' , 15 , 255 , 'construction' , 2 , False , True , (150,100,100) ), 140 | Label( 'tunnel' , 16 , 255 , 'construction' , 2 , False , True , (150,120, 90) ), 141 | Label( 'pole' , 17 , 5 , 'object' , 3 , False , False , (153,153,153) ), 142 | Label( 'polegroup' , 18 , 255 , 'object' , 3 , False , True , (153,153,153) ), 143 | Label( 'traffic light' , 19 , 6 , 'object' , 3 , False , False , (250,170, 30) ), 144 | Label( 'traffic sign' , 20 , 7 , 'object' , 3 , False , False , (220,220, 0) ), 145 | Label( 'vegetation' , 21 , 8 , 'nature' , 4 , False , False , (107,142, 35) ), 146 | Label( 'terrain' , 22 , 9 , 'nature' , 4 , False , False , (152,251,152) ), 147 | Label( 'sky' , 23 , 10 , 'sky' , 5 , False , False , ( 70,130,180) ), 148 | Label( 'person' , 24 , 11 , 'human' , 6 , True , False , (220, 20, 60) ), 149 | Label( 'rider' , 25 , 12 , 'human' , 6 , True , False , (255, 0, 0) ), 150 | Label( 'car' , 26 , 13 , 'vehicle' , 7 , True , False , ( 0, 0,142) ), 151 | Label( 'truck' , 27 , 14 , 'vehicle' , 7 , True , False , ( 0, 0, 70) ), 152 | Label( 'bus' , 28 , 15 , 'vehicle' , 7 , True , False , ( 0, 60,100) ), 153 | Label( 'caravan' , 29 , 255 , 'vehicle' , 7 , True , True , ( 0, 0, 90) ), 154 | Label( 'trailer' , 30 , 255 , 'vehicle' , 7 , True , True , ( 0, 0,110) ), 155 | Label( 'train' , 31 , 16 , 'vehicle' , 7 , True , False , ( 0, 80,100) ), 156 | Label( 'motorcycle' , 32 , 17 , 'vehicle' , 7 , True , False , ( 0, 0,230) ), 157 | Label( 'bicycle' , 33 , 18 , 'vehicle' , 7 , True , False , (119, 11, 32) ), 158 | Label( 'license plate' , -1 , 255 , 'vehicle' , 7 , False , True , ( 0, 0,142) ), 159 | ]) 160 | 161 | 162 | dataset_labels = { 163 | 'cityscapes': labels_cityscape_seg, 164 | } 165 | -------------------------------------------------------------------------------- /dataloader/eval/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /dataloader/eval/metrics.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2020 Marvin Klingner 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | import numpy as np 24 | import warnings 25 | 26 | 27 | # from https://github.com/tinghuiz/SfMLearner 28 | def dump_xyz(source_to_target_transformations): 29 | xyzs = [] 30 | cam_to_world = np.eye(4) 31 | xyzs.append(cam_to_world[:3, 3]) 32 | for source_to_target_transformation in source_to_target_transformations: 33 | cam_to_world = np.dot(cam_to_world, source_to_target_transformation) 34 | xyzs.append(cam_to_world[:3, 3]) 35 | return xyzs 36 | 37 | 38 | # from https://github.com/tinghuiz/SfMLearner 39 | def compute_ate(gtruth_xyz, pred_xyz_o): 40 | 41 | # Make sure that the first matched frames align (no need for rotational alignment as 42 | # all the predicted/ground-truth snippets have been converted to use the same coordinate 43 | # system with the first frame of the snippet being the origin). 44 | offset = gtruth_xyz[0] - pred_xyz_o[0] 45 | pred_xyz = pred_xyz_o + offset[None, :] 46 | 47 | # Optimize the scaling factor 48 | scale = np.sum(gtruth_xyz * pred_xyz) / np.sum(pred_xyz ** 2) 49 | alignment_error = pred_xyz * scale - gtruth_xyz 50 | rmse = np.sqrt(np.sum(alignment_error ** 2)) / gtruth_xyz.shape[0] 51 | return rmse 52 | 53 | 54 | class Evaluator(object): 55 | # CONF MATRIX 56 | # 0 1 2 (PRED) 57 | # 0 |TP FN FN| 58 | # 1 |FP TP FN| 59 | # 2 |FP FP TP| 60 | # (GT) 61 | # -> rows (axis=1) are FN 62 | # -> columns (axis=0) are FP 63 | @staticmethod 64 | def iou(conf): # TP / (TP + FN + FP) 65 | with warnings.catch_warnings(): 66 | warnings.filterwarnings('ignore') 67 | iu = np.diag(conf) / (conf.sum(axis=1) + conf.sum(axis=0) - np.diag(conf)) 68 | meaniu = np.nanmean(iu) 69 | result = {'iou': dict(zip(range(len(iu)), iu)), 'meaniou': meaniu} 70 | return result 71 | 72 | @staticmethod 73 | def accuracy(conf): # TP / (TP + FN) aka 'Recall' 74 | # Add 'add' in order to avoid division by zero and consequently NaNs in iu 75 | with warnings.catch_warnings(): 76 | warnings.filterwarnings('ignore') 77 | totalacc = np.diag(conf).sum() / (conf.sum()) 78 | acc = np.diag(conf) / (conf.sum(axis=1)) 79 | meanacc = np.nanmean(acc) 80 | result = {'totalacc': totalacc, 'meanacc': meanacc, 'acc': acc} 81 | return result 82 | 83 | @staticmethod 84 | def precision(conf): # TP / (TP + FP) 85 | # Add 'add' in order to avoid division by zero and consequently NaNs in iu 86 | with warnings.catch_warnings(): 87 | warnings.filterwarnings('ignore') 88 | prec = np.diag(conf) / (conf.sum(axis=0)) 89 | meanprec = np.nanmean(prec) 90 | result = {'meanprec': meanprec, 'prec': prec} 91 | return result 92 | 93 | @staticmethod 94 | def freqwacc(conf): 95 | # Add 'add' in order to avoid division by zero and consequently NaNs in iu 96 | with warnings.catch_warnings(): 97 | warnings.filterwarnings('ignore') 98 | iu = np.diag(conf) / (conf.sum(axis=1) + conf.sum(axis=0) - np.diag(conf)) 99 | freq = conf.sum(axis=1) / (conf.sum()) 100 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum() 101 | result = {'freqwacc': fwavacc} 102 | return result 103 | 104 | @staticmethod 105 | def depththresh(gt, pred): 106 | thresh = np.maximum((gt / pred), (pred / gt)) 107 | a1 = (thresh < 1.25).mean() 108 | a2 = (thresh < 1.25 ** 2).mean() 109 | a3 = (thresh < 1.25 ** 3).mean() 110 | 111 | result = {'delta1': a1, 'delta2': a2, 'delta3': a3} 112 | 
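# delta1/2/3 are the standard depth accuracy scores: the fraction of pixels whose ratio
# max(gt/pred, pred/gt) stays below 1.25, 1.25**2 and 1.25**3, i.e. the 'δ < 1.25' columns
# reported in the README results table.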
return result 113 | 114 | @staticmethod 115 | def deptherror(gt, pred): 116 | rmse = (gt - pred) ** 2 117 | rmse = np.sqrt(rmse.mean()) 118 | rmse_log = (np.log(gt) - np.log(pred)) ** 2 119 | rmse_log = np.sqrt(rmse_log.mean()) 120 | abs_rel = np.mean(np.abs(gt - pred) / gt) 121 | sq_rel = np.mean(((gt - pred) ** 2) / gt) 122 | 123 | result = {'abs_rel': abs_rel, 'sq_rel': sq_rel, 'rmse': rmse, 'rmse_log': rmse_log} 124 | return result 125 | 126 | 127 | class SegmentationRunningScore(object): 128 | def __init__(self, n_classes=20): 129 | self.n_classes = n_classes 130 | self.confusion_matrix = np.zeros((n_classes, n_classes)) 131 | 132 | def _fast_hist(self, label_true, label_pred, n_class): 133 | mask_true = (label_true >= 0) & (label_true < n_class) 134 | mask_pred = (label_pred >= 0) & (label_pred < n_class) 135 | mask = mask_pred & mask_true 136 | label_true = label_true[mask].astype(np.int) 137 | label_pred = label_pred[mask].astype(np.int) 138 | hist = np.bincount(n_class * label_true + label_pred, 139 | minlength=n_class*n_class).reshape(n_class, n_class).astype(np.float) 140 | return hist 141 | 142 | def update(self, label_trues, label_preds): 143 | # label_preds = label_preds.exp() 144 | # label_preds = label_preds.argmax(1).cpu().numpy() # filter out the best projected class for each pixel 145 | # label_trues = label_trues.numpy() # convert to numpy array 146 | 147 | for lt, lp in zip(label_trues, label_preds): 148 | self.confusion_matrix += self._fast_hist(lt.flatten(), lp.flatten(), self.n_classes) #update confusion matrix 149 | 150 | def get_scores(self, listofparams=None): 151 | """Returns the evaluation params specified in the list""" 152 | possibleparams = { 153 | 'iou': Evaluator.iou, 154 | 'acc': Evaluator.accuracy, 155 | 'freqwacc': Evaluator.freqwacc, 156 | 'prec': Evaluator.precision 157 | } 158 | if listofparams is None: 159 | listofparams = possibleparams 160 | 161 | result = {} 162 | for param in listofparams: 163 | if param in possibleparams.keys(): 164 | result.update(possibleparams[param](self.confusion_matrix)) 165 | return result 166 | 167 | def reset(self): 168 | self.confusion_matrix = np.zeros((self.n_classes, self.n_classes)) 169 | 170 | 171 | class DepthRunningScore(object): 172 | def __init__(self): 173 | self.num_samples = 0 174 | self.depth_thresh = {'delta1': 0, 'delta2': 0, 'delta3': 0} 175 | self.depth_errors = {'abs_rel': 0, 'sq_rel': 0, 'rmse': 0, 'rmse_log': 0} 176 | 177 | def update(self, ground_truth, prediction): 178 | if isinstance(ground_truth, list): 179 | self.num_samples += len(ground_truth) 180 | else: 181 | ground_truth = [ground_truth] 182 | prediction = [prediction] 183 | self.num_samples += 1 184 | 185 | for k in range(len(ground_truth)): 186 | gt = ground_truth[k].astype(np.float) 187 | pred = prediction[k].astype(np.float) 188 | thresh = Evaluator.depththresh(gt, pred) 189 | error = Evaluator.deptherror(gt, pred) 190 | for i, j in zip(thresh.keys(), self.depth_thresh.keys()): 191 | self.depth_thresh[i] += thresh[j] 192 | for i, j in zip(error.keys(), self.depth_errors.keys()): 193 | self.depth_errors[i] += error[j] 194 | 195 | def get_scores(self, listofparams=None): 196 | """Returns the evaluation params specified in the list""" 197 | possibleparams = { 198 | 'thresh': self.depth_thresh, 199 | 'error': self.depth_errors, 200 | } 201 | if listofparams is None: 202 | listofparams = possibleparams 203 | 204 | result = {} 205 | for param in listofparams: 206 | if param in possibleparams.keys(): 207 | 
result.update(possibleparams[param]) 208 | for i in result.keys(): 209 | result[i] = result[i]/self.num_samples 210 | 211 | return result 212 | 213 | def reset(self): 214 | self.num_samples = 0 215 | self.depth_thresh = {'delta1': 0, 'delta2': 0, 'delta3': 0} 216 | self.depth_errors = {'abs_rel': 0, 'sq_rel': 0, 'rmse': 0, 'rmse_log': 0} 217 | 218 | 219 | class PoseRunningScore(object): 220 | def __init__(self): 221 | self.preds = list() 222 | self.gts = list() 223 | 224 | def update(self, ground_truth, prediction): 225 | if isinstance(ground_truth, list): 226 | self.gts += ground_truth 227 | else: 228 | self.gts += [ground_truth] 229 | 230 | if isinstance(prediction, list): 231 | self.preds += prediction 232 | else: 233 | self.preds += [prediction] 234 | 235 | def get_scores(self): 236 | """Returns the evaluation params specified in the list""" 237 | 238 | gt_global_poses = np.concatenate(self.gts) 239 | pred_poses = np.concatenate(self.preds) 240 | 241 | gt_global_poses = np.concatenate( 242 | (gt_global_poses, np.zeros((gt_global_poses.shape[0], 1, 4))), 1) 243 | gt_global_poses[:, 3, 3] = 1 244 | gt_xyzs = gt_global_poses[:, :3, 3] 245 | gt_local_poses = [] 246 | for i in range(1, len(gt_global_poses)): 247 | gt_local_poses.append( 248 | np.linalg.inv(np.dot(np.linalg.inv(gt_global_poses[i - 1]), gt_global_poses[i]))) 249 | ates = [] 250 | num_frames = gt_xyzs.shape[0] 251 | track_length = 5 252 | for i in range(0, num_frames - track_length + 1): 253 | local_xyzs = np.array(dump_xyz(pred_poses[i:i + track_length - 1])) 254 | gt_local_xyzs = np.array(dump_xyz(gt_local_poses[i:i + track_length - 1])) 255 | ates.append(compute_ate(gt_local_xyzs, local_xyzs)) 256 | 257 | pose_error = {'mean': np.mean(ates), 'std': np.std(ates)} 258 | return pose_error 259 | 260 | def reset(self): 261 | self.preds = list() 262 | self.gts = list() 263 | 264 | 265 | class AverageMeter(object): 266 | """Computes and stores the average and current value""" 267 | 268 | def __init__(self): 269 | self.reset() 270 | 271 | def reset(self): 272 | self.val = 0 273 | self.avg = 0 274 | self.sum = 0 275 | self.count = 0 276 | 277 | def update(self, val, n=1): 278 | self.val = val 279 | self.sum += val * n 280 | self.count += n 281 | self.avg = self.sum / self.count 282 | -------------------------------------------------------------------------------- /dataloader/file_io/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ifnspaml/SGDepth/7d594b5153801d8240ec809ef6065eff1cd1f4fc/dataloader/file_io/__init__.py -------------------------------------------------------------------------------- /dataloader/file_io/dir_lister.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2020 Marvin Klingner 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 
14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | import os 24 | import sys 25 | 26 | import dataloader.file_io.get_path as gp 27 | 28 | 29 | class DirLister: 30 | """ This class will provide methods that enable the creation 31 | file lists and may return them in desired formats""" 32 | def __init__(self): 33 | pass 34 | 35 | @staticmethod 36 | def check_formats(cur_dir=None, file_ending=None): 37 | """ method to check if specified parameters have the right format 38 | 39 | :param cur_dir: directory which is checked for existance 40 | :param file_ending: file ending that is checked for the right format 41 | """ 42 | check = True 43 | if cur_dir is not None: 44 | if os.path.isdir(cur_dir) is False: 45 | print("the specified directory does not exist") 46 | check = False 47 | if file_ending is not None: 48 | if file_ending[0] != '.': 49 | print("the file ending has no '.' at the beginning") 50 | check = False 51 | return check 52 | 53 | @staticmethod 54 | def list_subdirectories(top_dir): 55 | """ method that lists all subdirectories of a given directory 56 | 57 | :param top_dir: directory in which the subdirectories are searched in 58 | """ 59 | top_dir = os.path.abspath(top_dir) 60 | sub_dirs = [os.path.join(top_dir, x) for x in os.listdir(top_dir) 61 | if os.path.isdir(os.path.join(top_dir, x))] 62 | return sub_dirs 63 | 64 | @staticmethod 65 | def list_files_in_directory(top_dir): 66 | """ method that lists all files of a given directory 67 | 68 | :param top_dir: directory in which the files are searched in 69 | """ 70 | top_dir = os.path.abspath(top_dir) 71 | files = [os.path.join(top_dir, x) for x in os.listdir(top_dir) 72 | if os.path.isfile(os.path.join(top_dir, x))] 73 | return files 74 | 75 | @staticmethod 76 | def get_directories(parent_dir): 77 | """ method that lists all directories of a given directory recursively 78 | 79 | :param parent_dir: directory in which the subdirectories are searched in 80 | """ 81 | if DirLister.check_formats(cur_dir=parent_dir) is False: 82 | sys.exit("Inputparameter überprüfen") 83 | parent_dir = os.path.abspath(parent_dir) 84 | sub_dirs = [] 85 | still_to_search = DirLister.list_subdirectories(parent_dir) 86 | 87 | while len(still_to_search) > 0: 88 | curr_sub_dirs = DirLister.list_subdirectories(still_to_search[0]) 89 | if len(curr_sub_dirs) == 0: 90 | sub_dirs.append(still_to_search[0]) 91 | else: 92 | still_to_search.extend(curr_sub_dirs) 93 | still_to_search.remove(still_to_search[0]) 94 | 95 | return sub_dirs 96 | 97 | @staticmethod 98 | def include_files_by_name(file_list, names, positions): 99 | """ takes a list of filepaths and keeps only the files which have all strings 100 | inside of their path specified by the list names 101 | 102 | :param file_list: list of filepaths 103 | :param names: strings which have to be inside the directory name 104 | :param positions: positions inside the dataset which are also only kept if the element is kept 105 | """ 106 | if type(names) == list: 107 | for name in names: 108 | positions = [positions[i] for i in range(len(file_list)) if name 
in file_list[i]] 109 | file_list = [x for x in file_list if name in x] 110 | elif type(names) == str: 111 | name = names 112 | positions = [positions[i] for i in range(len(file_list)) if name in file_list[i]] 113 | file_list = [x for x in file_list if name in x] 114 | return file_list, positions 115 | 116 | @staticmethod 117 | def include_files_by_folder(file_list, names, positions): 118 | """ takes a list of filepaths and keeps only the files which have all strings 119 | inside of their path specified by the list names 120 | 121 | :param file_list: list of filepaths 122 | :param names: folders which have to be inside the directory path 123 | :param positions: positions inside the dataset which are also only kept if the element is kept 124 | """ 125 | if type(names) == list: 126 | for name in names: 127 | positions = [positions[i] for i in range(len(file_list)) if name in file_list[i]] 128 | file_list = [x for x in file_list if name + os.sep in x or name == os.path.split(x)[1]] 129 | elif type(names) == str: 130 | name = names 131 | positions = [positions[i] for i in range(len(file_list)) if name in file_list[i]] 132 | file_list = [x for x in file_list if name + os.sep in x or name == os.path.split(x)[1]] 133 | return file_list, positions 134 | 135 | @staticmethod 136 | def include_dirs_by_name(dir_list, names, ignore=(), ambiguous_names_to_ignore=()): 137 | """ takes a list of directories and includes the directories which have all strings 138 | of the ones specified by the list names 139 | 140 | :param dir_list: list of directories 141 | :param names: strings which have to be inside the directory name 142 | :param ignore: string that must not be inside the directory name 143 | :param ambiguous_names_to_ignore: A list containing all strings that should not be taken into account when 144 | comparing to names. For example, if an upper folder is called 'dataset_images' and one filter name 145 | is also 'images' (e.g. 
for the color image), then this parameter will prevent all folder from being 146 | returned 147 | :return: a list of all folders containing all names, excluding those containing a string in ignore 148 | """ 149 | shortened_dir_list = dir_list.copy() 150 | if type(ambiguous_names_to_ignore) == str: 151 | ambiguous_names_to_ignore = (ambiguous_names_to_ignore, ) 152 | for ambiguous_name in ambiguous_names_to_ignore: 153 | shortened_dir_list = [x.replace(ambiguous_name, '') for x in shortened_dir_list] 154 | if type(names) == list: 155 | for name in names: 156 | dir_list = [x for x, xs in zip(dir_list, shortened_dir_list) if name in xs] 157 | shortened_dir_list = [xs for xs in shortened_dir_list if name in xs] 158 | elif type(names) == str: 159 | name = names 160 | dir_list = [x for x, xs in zip(dir_list, shortened_dir_list) if name in xs] 161 | for ignore_string in ignore: 162 | dir_list = [x for x in dir_list if ignore_string not in x] 163 | return dir_list 164 | 165 | @staticmethod 166 | def include_dirs_by_folder(dir_list, names): 167 | """ takes a list of directories and includes the directories which have all strings 168 | of the ones specified by the list names 169 | 170 | :param dir_list: list of directories 171 | :param names: folders which have to be inside the directory path 172 | """ 173 | if type(names) == list: 174 | for name in names: 175 | dir_list = [x for x in dir_list if name + os.sep in x or name == os.path.split(x)[1]] 176 | elif type(names) == str: 177 | name = names 178 | dir_list = [x for x in dir_list if name + os.sep in x or name == os.path.split(x)[1]] 179 | return dir_list 180 | 181 | @staticmethod 182 | def remove_dirs_by_name(dir_list, names): 183 | """ takes a list of directories and removes the directories which have at least one string 184 | of the ones specified by the list names 185 | 186 | :param dir_list: list of directories 187 | :param names: strings which are not allowed inside the directory name 188 | """ 189 | if type(names) == list: 190 | for name in names: 191 | dir_list = [x for x in dir_list if name not in x] 192 | elif type(names) == str: 193 | name = names 194 | dir_list = [x for x in dir_list if name not in x] 195 | return dir_list 196 | 197 | @staticmethod 198 | def get_files_by_ending(cur_dir, file_ending, ignore = []): 199 | """ returns all files inside a directory which have a certain ending 200 | 201 | :param cur_dir: list of directories 202 | :param file_ending: all files with the specified file_ending are returned 203 | :param ignore: list of strings. Filenames containing one of these strings will be ignored. 
204 | :return: all files inside cur_dir which have the ending file_ending 205 | """ 206 | if DirLister.check_formats(cur_dir=cur_dir, 207 | file_ending=file_ending) is False: 208 | sys.exit("Inputparameter überprüfen") 209 | files = DirLister.list_files_in_directory(cur_dir) 210 | len_ending = len(file_ending) 211 | files = [x for x in files if x[-len_ending:] == file_ending] 212 | for ignore_string in ignore: 213 | files = [x for x in files if ignore_string not in x] 214 | return files 215 | 216 | 217 | if __name__ == '__main__': 218 | """can be used for testing purposes""" 219 | path_getter = gp.GetPath() 220 | path = path_getter.get_data_path() 221 | path = os.path.join(path, 'Cityscapes') 222 | a = DirLister() 223 | test = a.get_directories(path) 224 | print(a.include_dirs_by_name(test, 'test')) 225 | -------------------------------------------------------------------------------- /dataloader/file_io/get_path.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2020 Marvin Klingner 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | import os 24 | import platform 25 | import socket 26 | import json 27 | 28 | class GetPath: 29 | def __init__(self): 30 | """This class gives the paths that are needed for training and testing neural networks. 31 | 32 | Paths that need to be specified are the data path and the checkpoint path, where the models will be saved. 33 | The paths have to saved in environment variables called IFN_DIR_DATASET and IFN_DIR_CHECKPOINT, respectively. 34 | """ 35 | 36 | # Check if the user did explicitly set environment variables 37 | if self._guess_by_env(): 38 | return 39 | 40 | # Print a helpful text when no directories could be found 41 | if platform.system() == 'Windows': 42 | raise ValueError( 43 | 'Could not determine dataset/checkpoint directory. ' 44 | 'You can use environment variables to specify these directories ' 45 | 'by using the following commands:\n' 46 | 'setx IFN_DIR_DATASET \n' 47 | 'setx IFN_DIR_CHECKPOINT \n' 48 | ) 49 | 50 | else: 51 | raise ValueError( 52 | 'Could not determine dataset/checkpoint directory. 
' 53 | 'You can use environment variables to specify these directories ' 54 | 'by adding lines like the following to your ~/.bashrc:\n' 55 | 'export IFN_DIR_DATASET=\n' 56 | 'export IFN_DIR_CHECKPOINT=' 57 | ) 58 | 59 | def _check_dirs(self): 60 | if self.dataset_base_path is None: 61 | return False 62 | 63 | if self.checkpoint_base_path is None: 64 | return False 65 | 66 | if not os.path.isdir(self.dataset_base_path): 67 | return False 68 | 69 | return True 70 | 71 | def _guess_by_env(self): 72 | dataset_base = os.environ.get('IFN_DIR_DATASET', None) 73 | checkpoint_base = os.environ.get('IFN_DIR_CHECKPOINT', None) 74 | 75 | self.dataset_base_path = dataset_base 76 | self.checkpoint_base_path = checkpoint_base 77 | 78 | return self._check_dirs() 79 | 80 | def get_data_path(self): 81 | """returns the path to the dataset folder""" 82 | 83 | return self.dataset_base_path 84 | 85 | def get_checkpoint_path(self): 86 | """returns the path to the checkpoints of the models""" 87 | 88 | return self.checkpoint_base_path 89 | -------------------------------------------------------------------------------- /dataloader/pt_data_loader/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ifnspaml/SGDepth/7d594b5153801d8240ec809ef6065eff1cd1f4fc/dataloader/pt_data_loader/__init__.py -------------------------------------------------------------------------------- /dataloader/pt_data_loader/dataset_parameterset.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2020 Marvin Klingner 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | import os 24 | import sys 25 | import json 26 | import numpy as np 27 | 28 | import dataloader.file_io.get_path as gp 29 | import dataloader.definitions.labels_file as lf 30 | 31 | 32 | class DatasetParameterset: 33 | """A class that contains all dataset-specific parameters 34 | 35 | - K: Extrinsic camera matrix as a Numpy array. If not available, take None 36 | - stereo_T: Distance between the two cameras (see e.g. http://www.cvlibs.net/datasets/kitti/setup.php, 0.54m) 37 | - labels: 38 | - labels_mode: 'fromid' or 'fromrgb', depending on which format the segmentation images have 39 | - depth_mode: 'uint_16' or 'uint_16_subtract_one' depending on which format the depth images have 40 | - flow_mode: specifies how the flow images are stored, e.g. 
'kitti' 41 | - splits: List of splits that are available for this dataset 42 | """ 43 | def __init__(self, dataset): 44 | path_getter = gp.GetPath() 45 | dataset_folder = path_getter.get_data_path() 46 | path = os.path.join(dataset_folder, dataset, 'parameters.json') 47 | if not os.path.isdir(os.path.join(dataset_folder, dataset)): 48 | raise Exception('There is no dataset folder called {}'.format(dataset)) 49 | if not os.path.isfile(path): 50 | raise Exception('There is no parameters.json file in the dataset folder. Please create it using the ' 51 | 'dataset_index.py in the folder dataloader/file_io in order to load this dataset') 52 | with open(path) as file: 53 | param_dict = json.load(file) 54 | self._dataset = dataset 55 | self._K = param_dict['K'] 56 | if self._K is not None: 57 | self._K = np.array(self._K, dtype=np.float32) 58 | if param_dict['stereo_T'] is not None: 59 | self._stereo_T = np.eye(4, dtype=np.float32) 60 | self._stereo_T[0, 3] = param_dict['stereo_T'] 61 | else: 62 | self._stereo_T = None 63 | self._depth_mode = param_dict['depth_mode'] 64 | self._flow_mode = param_dict['flow_mode'] 65 | self._splits = param_dict['splits'] 66 | labels_name = param_dict['labels'] 67 | if labels_name in lf.dataset_labels.keys(): 68 | self.labels = lf.dataset_labels[labels_name].getlabels() 69 | self.labels_mode = param_dict['labels_mode'] 70 | else: 71 | self.labels = None 72 | self.labels_mode = None 73 | 74 | @property 75 | def dataset(self): 76 | return self._dataset 77 | 78 | @property 79 | def K(self): 80 | return self._K 81 | 82 | @property 83 | def stereo_T(self): 84 | return self._stereo_T 85 | 86 | @property 87 | def depth_mode(self): 88 | return self._depth_mode 89 | 90 | @property 91 | def flow_mode(self): 92 | return self._flow_mode 93 | 94 | @property 95 | def splits(self): 96 | return self._splits 97 | 98 | 99 | 100 | -------------------------------------------------------------------------------- /dataloader/pt_data_loader/specialdatasets.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2020 Marvin Klingner 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 
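# The StandardDataset below relies on per-dataset parameters that
# dataloader/pt_data_loader/dataset_parameterset.py loads from a parameters.json file inside
# the dataset folder. A minimal sketch of such a file, assuming KITTI-like values (the exact
# shape of K, the numbers and the label/split names are assumptions for illustration, not
# shipped defaults):
#
#   {
#       "K": [[0.58, 0.0, 0.5, 0.0],
#             [0.0, 1.92, 0.5, 0.0],
#             [0.0, 0.0, 1.0, 0.0],
#             [0.0, 0.0, 0.0, 1.0]],
#       "stereo_T": 0.54,
#       "labels": "kitti",
#       "labels_mode": "fromid",
#       "depth_mode": "uint_16",
#       "flow_mode": "kitti",
#       "splits": ["kitti_split"]
#   }
#
# "K" is stored normalised to the image size; add_const_dataset_items() below scales its first
# two rows by the native image width and height before it is handed to the network.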
22 | 23 | from __future__ import absolute_import, division, print_function 24 | 25 | import sys 26 | import numpy as np 27 | import matplotlib.pyplot as plt 28 | from torch.utils.data import DataLoader 29 | 30 | from dataloader.pt_data_loader.basedataset import BaseDataset 31 | import dataloader.pt_data_loader.mytransforms as mytransforms 32 | import dataloader.definitions.labels_file as lf 33 | 34 | 35 | class StandardDataset(BaseDataset): 36 | def __init__(self, *args, **kwargs): 37 | super(StandardDataset, self).__init__(*args, **kwargs) 38 | 39 | if self.disable_const_items is False: 40 | assert self.parameters.K is not None and self.parameters.stereo_T is not None, '''There are no K matrix and 41 | stereo_T parameter available for this dataset.''' 42 | 43 | def add_const_dataset_items(self, sample): 44 | K = self.parameters.K.copy() 45 | 46 | native_key = ('color', 0, -1) if (('color', 0, -1) in sample) else ('color_right', 0, -1) 47 | native_im_shape = sample[native_key].shape 48 | 49 | K[0, :] *= native_im_shape[1] 50 | K[1, :] *= native_im_shape[0] 51 | 52 | sample["K", -1] = K 53 | sample["stereo_T"] = self.parameters.stereo_T 54 | 55 | return sample -------------------------------------------------------------------------------- /dc_masking.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | import numpy as np 3 | 4 | import torch 5 | 6 | from dataloader.eval.metrics import SegmentationRunningScore 7 | 8 | 9 | class DCMasking(object): 10 | def __init__(self, masking_from_epoch, num_epochs, moving_mask_percent, masking_linear_increase): 11 | self.masking_from_epoch = masking_from_epoch 12 | self.num_epochs = num_epochs 13 | self.moving_mask_percent = moving_mask_percent 14 | self.masking_linear_increase = masking_linear_increase 15 | 16 | self.segmentation_input_key = ('color_aug', 0, 0) 17 | self.logits_key = ('segmentation_logits', 0) 18 | 19 | self.metric_model_moving = SegmentationRunningScore(2) 20 | 21 | self.iou_thresh = dict() 22 | self.iou_thresh['non_moving'] = 0.0 23 | self.iou_thresh['moving'] = 0.0 24 | 25 | self.iou_log = dict() 26 | self.iou_log['non_moving'] = list() 27 | self.iou_log['moving'] = list() 28 | 29 | def _moving_class_criterion(self, segmentation): 30 | # TODO this is valid for the Cityscapes class definitions and has to be adapted for other datasets 31 | # to be more generic 32 | mask = (segmentation > 10) & (segmentation < 100) 33 | return mask 34 | 35 | def compute_segmentation_frames(self, batch, model): 36 | batch_masking = deepcopy(batch) 37 | 38 | # get the depth indices 39 | batch_indices = tuple([idx_batch for idx_batch, sub_batch in enumerate(batch_masking) 40 | if any('depth' in purpose_tuple 41 | for purpose_tuple in sub_batch['purposes']) 42 | ]) 43 | 44 | # get the depth images 45 | batch_masking = tuple([sub_batch for sub_batch in batch_masking 46 | if any('depth' in purpose_tuple 47 | for purpose_tuple in sub_batch['purposes']) 48 | ]) 49 | 50 | # replace the purpose to segmentation 51 | for idx1, sub_batch in enumerate(batch_masking): 52 | for idx2, purpose_tuple in enumerate(sub_batch['purposes']): 53 | batch_masking[idx1]['purposes'][idx2] = tuple([purpose.replace('depth', 'segmentation') 54 | for purpose in purpose_tuple]) 55 | 56 | # generate the correct keys and outputs 57 | input_image_keys = [key for key in batch_masking[0].keys() if 'color_aug' in key] 58 | output_segmentation_keys = [('segmentation', key[1], key[2]) for key in input_image_keys] 59 | 
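# For a typical three-frame depth batch the two comprehensions above would, for example, map
#   input_image_keys         = [('color_aug', 0, 0), ('color_aug', -1, 0), ('color_aug', 1, 0)]
# to
#   output_segmentation_keys = [('segmentation', 0, 0), ('segmentation', -1, 0), ('segmentation', 1, 0)],
# i.e. one segmentation output per (frame_id, scale) of the augmented colour inputs.
# (The frame ids -1/0/1 are an assumed example; the actual ids depend on the loader configuration.)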
outputs_masked = list(dict() for i in range(len(batch))) 60 | 61 | # pass all depth image frames through the network to get the segmentation outputs 62 | for in_key, out_key in zip(input_image_keys, output_segmentation_keys): 63 | wanted_keys = ['domain', 'purposes', 'domain_idx', in_key] 64 | batch_masking_key = deepcopy(batch_masking) 65 | batch_masking_key = tuple([{key: sub_batch[key] for key in sub_batch.keys() 66 | if key in wanted_keys} 67 | for sub_batch in batch_masking_key]) 68 | for idx1 in range(len(batch_masking_key)): 69 | batch_masking_key[idx1][self.segmentation_input_key] = \ 70 | batch_masking_key[idx1][in_key].clone() 71 | if in_key != self.segmentation_input_key: 72 | del batch_masking_key[idx1][in_key] 73 | 74 | outputs_masked_key = model(batch_masking_key) 75 | cur_idx_outputs = 0 76 | for idx_batch in range(len(outputs_masked)): 77 | if idx_batch in batch_indices: 78 | outputs_masked[idx_batch][out_key] = outputs_masked_key[cur_idx_outputs][self.logits_key].argmax(1) 79 | cur_idx_outputs += 1 80 | else: 81 | outputs_masked[idx_batch] = None 82 | 83 | outputs_masked = tuple(outputs_masked) 84 | return outputs_masked 85 | 86 | def compute_moving_mask(self, output_masked): 87 | """Compute moving mask and iou 88 | """ 89 | segmentation = output_masked[("segmentation", 0, 0)] 90 | # Create empty mask 91 | moving_mask_combined = torch.zeros(segmentation.shape).to(segmentation.device) 92 | # Create binary mask moving in t = 0, movable object = 1, non_movable = 0 93 | 94 | # Create binary masks (moving / non-moving) 95 | moving_mask = dict() 96 | moving_mask[0] = self._moving_class_criterion(segmentation).float() 97 | for key in output_masked.keys(): 98 | if key[0] == "segmentation_warped": 99 | moving_mask[key[1]] = self._moving_class_criterion(output_masked[("segmentation_warped", key[1], 0)]) 100 | 101 | # Calculate IoU for each frame separately 102 | for i in range(moving_mask[0].shape[0]): 103 | 104 | # Average score over frames 105 | for frame_id in moving_mask.keys(): 106 | if frame_id == 0: 107 | continue 108 | # For binary class 109 | self.metric_model_moving.update( 110 | np.array(moving_mask[frame_id][i].cpu()), np.array(moving_mask[0][i].cpu())) 111 | 112 | scores = self.metric_model_moving.get_scores() 113 | 114 | if not np.isnan(scores['iou'][0]): 115 | self.iou_log['non_moving'].append(scores['iou'][0]) 116 | if not np.isnan(scores['iou'][1]): 117 | self.iou_log['moving'].append(scores['iou'][1]) 118 | # Calculate Mask if scores of moving objects is not NaN 119 | # mask every moving class, were the iou is smaller than threshold 120 | if scores['iou'][1] < self.iou_thresh['moving']: 121 | # Add moving mask of t = 0 122 | moving_mask_combined[i] += self._moving_class_criterion(segmentation[i]).float() 123 | # Add moving mask of segmentation mask of t!=0 warped to t=0 124 | for frame_id in moving_mask.keys(): 125 | if frame_id == 0: 126 | continue 127 | moving_mask_combined[i] += self._moving_class_criterion( 128 | output_masked[("segmentation_warped", frame_id, 0)][i]).float() 129 | # mask moving in t != 0 130 | self.metric_model_moving.reset() 131 | # movable object = 0, non_movable = 1 132 | output_masked['moving_mask'] = (moving_mask_combined < 1).float().detach() 133 | 134 | def clear_iou_log(self): 135 | self.iou_log = dict() 136 | self.iou_log['non_moving'] = list() 137 | self.iou_log['moving'] = list() 138 | 139 | def calculate_iou_threshold(self, current_epoch): 140 | if self.masking_from_epoch <= current_epoch: 141 | self.iou_thresh = dict() 142 | if 
self.masking_linear_increase: 143 | percentage = 1 - (1 / (self.num_epochs - 1 - self.masking_from_epoch) * ( 144 | current_epoch + 1 - self.masking_from_epoch)) # Mask 100 % to 0 % 145 | else: 146 | percentage = self.moving_mask_percent 147 | try: 148 | self.iou_thresh['non_moving'] = np.percentile(self.iou_log['non_moving'], (100 * percentage)).item() 149 | except Exception as e: 150 | self.iou_thresh['non_moving'] = 0.0 151 | print('Error calculating percentile of non_moving') 152 | print(e) 153 | try: 154 | self.iou_thresh['moving'] = np.percentile(self.iou_log['moving'], (100 * percentage)).item() 155 | except Exception as e: 156 | self.iou_thresh['moving'] = 0.0 157 | print('Error calculating percentile of moving') 158 | print(e) 159 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: torch_110 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - _ipyw_jlab_nb_ext_conf=0.1.0=py37_0 7 | - _libgcc_mutex=0.1=main 8 | - alabaster=0.7.12=py37_0 9 | - anaconda-client=1.7.2=py37_0 10 | - anaconda-project=0.8.4=py_0 11 | - argh=0.26.2=py37_0 12 | - asn1crypto=1.3.0=py37_0 13 | - astroid=2.3.3=py37_0 14 | - astropy=4.0=py37h7b6447c_0 15 | - atomicwrites=1.3.0=py37_1 16 | - attrs=19.3.0=py_0 17 | - autopep8=1.4.4=py_0 18 | - babel=2.8.0=py_0 19 | - backcall=0.1.0=py37_0 20 | - backports=1.0=py_2 21 | - backports.functools_lru_cache=1.6.1=py_0 22 | - backports.shutil_get_terminal_size=1.0.0=py37_2 23 | - backports.tempfile=1.0=py_1 24 | - backports.weakref=1.0.post1=py_1 25 | - beautifulsoup4=4.8.2=py37_0 26 | - bitarray=1.2.1=py37h7b6447c_0 27 | - bkcharts=0.2=py37_0 28 | - blas=1.0=mkl 29 | - bleach=3.1.0=py37_0 30 | - blosc=1.16.3=hd408876_0 31 | - boto=2.49.0=py37_0 32 | - bottleneck=1.3.2=py37heb32a55_0 33 | - bzip2=1.0.8=h7b6447c_0 34 | - ca-certificates=2020.6.24=0 35 | - cairo=1.14.12=h8948797_3 36 | - certifi=2020.6.20=py37_0 37 | - cffi=1.14.0=py37h2e261b9_0 38 | - chardet=3.0.4=py37_1003 39 | - click=7.0=py37_0 40 | - cloudpickle=1.3.0=py_0 41 | - clyent=1.2.2=py37_1 42 | - colorama=0.4.3=py_0 43 | - conda=4.6.14=py37_0 44 | - conda-package-handling=1.6.0=py37h7b6447c_0 45 | - conda-verify=3.4.2=py_1 46 | - contextlib2=0.6.0.post1=py_0 47 | - cryptography=2.8=py37h1ba5d50_0 48 | - cudatoolkit=10.0.130=0 49 | - curl=7.68.0=hbc83047_0 50 | - cycler=0.10.0=py37_0 51 | - cython=0.29.15=py37he6710b0_0 52 | - cytoolz=0.10.1=py37h7b6447c_0 53 | - dask-core=2.11.0=py_0 54 | - dbus=1.13.12=h746ee38_0 55 | - decorator=4.4.1=py_0 56 | - defusedxml=0.6.0=py_0 57 | - diff-match-patch=20181111=py_0 58 | - distributed=2.11.0=py37_0 59 | - docutils=0.16=py37_0 60 | - entrypoints=0.3=py37_0 61 | - et_xmlfile=1.0.1=py37_0 62 | - expat=2.2.6=he6710b0_0 63 | - fastcache=1.1.0=py37h7b6447c_0 64 | - filelock=3.0.12=py_0 65 | - flake8=3.7.9=py37_0 66 | - flask=1.1.1=py_0 67 | - fontconfig=2.13.0=h9420a91_0 68 | - freetype=2.9.1=h8a8886c_1 69 | - fribidi=1.0.5=h7b6447c_0 70 | - fsspec=0.6.2=py_0 71 | - future=0.18.2=py37_0 72 | - get_terminal_size=1.0.0=haa9412d_0 73 | - gevent=1.4.0=py37h7b6447c_0 74 | - glib=2.63.1=h5a9c865_0 75 | - glob2=0.7=py_0 76 | - gmp=6.1.2=h6c8ec71_1 77 | - gmpy2=2.0.8=py37h10f8cd9_2 78 | - graphite2=1.3.13=h23475e2_0 79 | - greenlet=0.4.15=py37h7b6447c_0 80 | - gst-plugins-base=1.14.0=hbbd80ab_1 81 | - gstreamer=1.14.0=hb453b48_1 82 | - h5py=2.10.0=py37h7918eee_0 83 | - harfbuzz=1.8.8=hffaf4a1_0 84 | - hdf5=1.10.4=hb1b8bf9_0 85 
| - heapdict=1.0.1=py_0 86 | - html5lib=1.0.1=py37_0 87 | - hypothesis=5.5.4=py_0 88 | - icu=58.2=h9c2bf20_1 89 | - idna=2.8=py37_0 90 | - imagesize=1.2.0=py_0 91 | - importlib_metadata=1.5.0=py37_0 92 | - intel-openmp=2020.0=166 93 | - intervaltree=3.0.2=py_0 94 | - ipykernel=5.1.4=py37h39e3cac_0 95 | - ipython=7.12.0=py37h5ca1d4c_0 96 | - ipython_genutils=0.2.0=py37_0 97 | - ipywidgets=7.5.1=py_0 98 | - isort=4.3.21=py37_0 99 | - itsdangerous=1.1.0=py37_0 100 | - jbig=2.1=hdba287a_0 101 | - jdcal=1.4.1=py_0 102 | - jedi=0.14.1=py37_0 103 | - jeepney=0.4.2=py_0 104 | - jinja2=2.11.1=py_0 105 | - joblib=0.14.1=py_0 106 | - jpeg=9b=h024ee3a_2 107 | - json5=0.9.1=py_0 108 | - jsonschema=3.2.0=py37_0 109 | - jupyter=1.0.0=py37_7 110 | - jupyter_client=5.3.4=py37_0 111 | - jupyter_console=6.1.0=py_0 112 | - jupyter_core=4.6.1=py37_0 113 | - jupyterlab=1.2.6=pyhf63ae98_0 114 | - jupyterlab_server=1.0.6=py_0 115 | - keyring=21.1.0=py37_0 116 | - kiwisolver=1.1.0=py37he6710b0_0 117 | - krb5=1.17.1=h173b8e3_0 118 | - lazy-object-proxy=1.4.3=py37h7b6447c_0 119 | - ld_impl_linux-64=2.33.1=h53a641e_7 120 | - libarchive=3.3.3=h5d8350f_5 121 | - libcurl=7.68.0=h20c2e04_0 122 | - libedit=3.1.20181209=hc058e9b_0 123 | - libffi=3.2.1=hd88cf55_4 124 | - libgcc-ng=9.1.0=hdf63c60_0 125 | - libgfortran-ng=7.3.0=hdf63c60_0 126 | - liblief=0.9.0=h7725739_2 127 | - libpng=1.6.37=hbc83047_0 128 | - libsodium=1.0.16=h1bed415_0 129 | - libspatialindex=1.9.3=he6710b0_0 130 | - libssh2=1.8.2=h1ba5d50_0 131 | - libstdcxx-ng=9.1.0=hdf63c60_0 132 | - libtiff=4.1.0=h2733197_0 133 | - libtool=2.4.6=h7b6447c_5 134 | - libuuid=1.0.3=h1bed415_2 135 | - libxcb=1.13=h1bed415_1 136 | - libxml2=2.9.9=hea5a465_1 137 | - libxslt=1.1.33=h7d1a2b0_0 138 | - llvmlite=0.31.0=py37hd408876_0 139 | - locket=0.2.0=py37_1 140 | - lxml=4.5.0=py37hefd8a0e_0 141 | - lz4-c=1.8.1.2=h14c3975_0 142 | - lzo=2.10=h49e0be7_2 143 | - markupsafe=1.1.1=py37h7b6447c_0 144 | - matplotlib=3.1.3=py37_0 145 | - matplotlib-base=3.1.3=py37hef1b27d_0 146 | - mccabe=0.6.1=py37_1 147 | - mistune=0.8.4=py37h7b6447c_0 148 | - mkl=2020.0=166 149 | - mkl-service=2.3.0=py37he904b0f_0 150 | - mkl_fft=1.0.15=py37ha843d7b_0 151 | - mkl_random=1.1.0=py37hd6b4f25_0 152 | - mock=4.0.1=py_0 153 | - more-itertools=8.2.0=py_0 154 | - mpc=1.1.0=h10f8cd9_1 155 | - mpfr=4.0.1=hdf1c602_3 156 | - mpmath=1.1.0=py37_0 157 | - msgpack-python=0.6.1=py37hfd86e86_1 158 | - multipledispatch=0.6.0=py37_0 159 | - navigator-updater=0.2.1=py37_0 160 | - nbconvert=5.6.1=py37_0 161 | - nbformat=5.0.4=py_0 162 | - ncurses=6.2=he6710b0_0 163 | - networkx=2.4=py_0 164 | - ninja=1.9.0=py37hfd86e86_0 165 | - nltk=3.4.5=py37_0 166 | - nose=1.3.7=py37_2 167 | - notebook=6.0.3=py37_0 168 | - numba=0.48.0=py37h0573a6f_0 169 | - numexpr=2.7.1=py37h423224d_0 170 | - numpy=1.18.1=py37h4f9e942_0 171 | - numpy-base=1.18.1=py37hde5b4d6_1 172 | - numpydoc=0.9.2=py_0 173 | - olefile=0.46=py37_0 174 | - openpyxl=3.0.3=py_0 175 | - openssl=1.1.1g=h7b6447c_0 176 | - packaging=20.1=py_0 177 | - pandas=1.0.1=py37h0573a6f_0 178 | - pandoc=2.2.3.2=0 179 | - pandocfilters=1.4.2=py37_1 180 | - pango=1.42.4=h049681c_0 181 | - parso=0.5.2=py_0 182 | - partd=1.1.0=py_0 183 | - patchelf=0.10=he6710b0_0 184 | - path=13.1.0=py37_0 185 | - path.py=12.4.0=0 186 | - pathlib2=2.3.5=py37_0 187 | - pathtools=0.1.2=py_1 188 | - patsy=0.5.1=py37_0 189 | - pcre=8.43=he6710b0_0 190 | - pep8=1.7.1=py37_0 191 | - pexpect=4.8.0=py37_0 192 | - pickleshare=0.7.5=py37_0 193 | - pillow=6.1.0=py37h34e0f95_0 194 | - pip=20.0.2=py37_1 195 | - 
pixman=0.38.0=h7b6447c_0 196 | - pkginfo=1.5.0.1=py37_0 197 | - pluggy=0.13.1=py37_0 198 | - ply=3.11=py37_0 199 | - prometheus_client=0.7.1=py_0 200 | - prompt_toolkit=3.0.3=py_0 201 | - psutil=5.6.7=py37h7b6447c_0 202 | - ptyprocess=0.6.0=py37_0 203 | - py=1.8.1=py_0 204 | - py-lief=0.9.0=py37h7725739_2 205 | - pycodestyle=2.5.0=py37_0 206 | - pycosat=0.6.3=py37h7b6447c_0 207 | - pycparser=2.19=py37_0 208 | - pycrypto=2.6.1=py37h14c3975_9 209 | - pycurl=7.43.0.5=py37h1ba5d50_0 210 | - pydocstyle=4.0.1=py_0 211 | - pyflakes=2.1.1=py37_0 212 | - pygments=2.5.2=py_0 213 | - pylint=2.4.4=py37_0 214 | - pyodbc=4.0.30=py37he6710b0_0 215 | - pyopenssl=19.1.0=py37_0 216 | - pyparsing=2.4.6=py_0 217 | - pyqt=5.9.2=py37h05f1152_2 218 | - pyrsistent=0.15.7=py37h7b6447c_0 219 | - pysocks=1.7.1=py37_0 220 | - pytables=3.6.1=py37h71ec239_0 221 | - pytest=5.3.5=py37_0 222 | - pytest-arraydiff=0.3=py37h39e3cac_0 223 | - pytest-astropy=0.8.0=py_0 224 | - pytest-astropy-header=0.1.2=py_0 225 | - pytest-doctestplus=0.5.0=py_0 226 | - pytest-openfiles=0.4.0=py_0 227 | - pytest-remotedata=0.3.2=py37_0 228 | - python=3.7.6=h0371630_2 229 | - python-dateutil=2.8.1=py_0 230 | - python-jsonrpc-server=0.3.4=py_0 231 | - python-language-server=0.31.7=py37_0 232 | - python-libarchive-c=2.8=py37_13 233 | - pytorch=1.1.0=py3.7_cuda10.0.130_cudnn7.5.1_0 234 | - pytz=2019.3=py_0 235 | - pywavelets=1.1.1=py37h7b6447c_0 236 | - pyxdg=0.26=py_0 237 | - pyyaml=5.3=py37h7b6447c_0 238 | - pyzmq=18.1.1=py37he6710b0_0 239 | - qdarkstyle=2.8=py_0 240 | - qt=5.9.7=h5867ecd_1 241 | - qtawesome=0.6.1=py_0 242 | - qtconsole=4.6.0=py_1 243 | - qtpy=1.9.0=py_0 244 | - readline=7.0=h7b6447c_5 245 | - requests=2.22.0=py37_1 246 | - ripgrep=11.0.2=he32d670_0 247 | - rope=0.16.0=py_0 248 | - rtree=0.9.3=py37_0 249 | - ruamel_yaml=0.15.87=py37h7b6447c_0 250 | - scikit-learn=0.22.1=py37hd81dba3_0 251 | - scipy=1.4.1=py37h0b6359f_0 252 | - seaborn=0.10.0=py_0 253 | - secretstorage=3.1.2=py37_0 254 | - send2trash=1.5.0=py37_0 255 | - setuptools=45.2.0=py37_0 256 | - simplegeneric=0.8.1=py37_2 257 | - singledispatch=3.4.0.3=py37_0 258 | - sip=4.19.8=py37hf484d3e_0 259 | - six=1.14.0=py37_0 260 | - snappy=1.1.7=hbae5bb6_3 261 | - snowballstemmer=2.0.0=py_0 262 | - sortedcollections=1.1.2=py37_0 263 | - sortedcontainers=2.1.0=py37_0 264 | - soupsieve=1.9.5=py37_0 265 | - sphinx=2.4.0=py_0 266 | - sphinxcontrib=1.0=py37_1 267 | - sphinxcontrib-applehelp=1.0.1=py_0 268 | - sphinxcontrib-devhelp=1.0.1=py_0 269 | - sphinxcontrib-htmlhelp=1.0.2=py_0 270 | - sphinxcontrib-jsmath=1.0.1=py_0 271 | - sphinxcontrib-qthelp=1.0.2=py_0 272 | - sphinxcontrib-serializinghtml=1.1.3=py_0 273 | - sphinxcontrib-websupport=1.2.0=py_0 274 | - spyder=4.0.1=py37_0 275 | - spyder-kernels=1.8.1=py37_0 276 | - sqlalchemy=1.3.13=py37h7b6447c_0 277 | - sqlite=3.31.1=h7b6447c_0 278 | - statsmodels=0.11.0=py37h7b6447c_0 279 | - sympy=1.5.1=py37_0 280 | - tbb=2020.0=hfd86e86_0 281 | - tblib=1.6.0=py_0 282 | - terminado=0.8.3=py37_0 283 | - testpath=0.4.4=py_0 284 | - tk=8.6.8=hbc83047_0 285 | - toolz=0.10.0=py_0 286 | - torchvision=0.3.0=py37_cu10.0.130_1 287 | - tornado=6.0.3=py37h7b6447c_3 288 | - tqdm=4.42.1=py_0 289 | - traitlets=4.3.3=py37_0 290 | - ujson=1.35=py37h14c3975_0 291 | - unicodecsv=0.14.1=py37_0 292 | - unixodbc=2.3.7=h14c3975_0 293 | - urllib3=1.25.8=py37_0 294 | - watchdog=0.10.2=py37_0 295 | - wcwidth=0.1.8=py_0 296 | - webencodings=0.5.1=py37_1 297 | - werkzeug=1.0.0=py_0 298 | - wheel=0.34.2=py37_0 299 | - widgetsnbextension=3.5.1=py37_0 300 | - 
wrapt=1.11.2=py37h7b6447c_0 301 | - wurlitzer=2.0.0=py37_0 302 | - xlrd=1.2.0=py37_0 303 | - xlsxwriter=1.2.7=py_0 304 | - xlwt=1.3.0=py37_0 305 | - xmltodict=0.12.0=py_0 306 | - xz=5.2.4=h14c3975_4 307 | - yaml=0.1.7=had09818_2 308 | - yapf=0.28.0=py_0 309 | - zeromq=4.3.1=he6710b0_3 310 | - zict=1.0.0=py_0 311 | - zipp=2.2.0=py_0 312 | - zlib=1.2.11=h7b6447c_3 313 | - zstd=1.3.7=h0b5b093_0 314 | - pip: 315 | - opencv-contrib-python==4.2.0.34 316 | - protobuf==3.12.2 317 | - tensorboardx==2.0 318 | -------------------------------------------------------------------------------- /eval_depth.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Python standard library 4 | import os 5 | 6 | # Public libraries 7 | import numpy as np 8 | import torch 9 | from torchvision.utils import save_image 10 | 11 | # Local imports 12 | import colors 13 | from arguments import DepthEvaluationArguments 14 | from harness import Harness 15 | 16 | 17 | class DepthEvaluator(Harness): 18 | def _init_validation(self, opt): 19 | self.fixed_depth_scaling = opt.depth_validation_fixed_scaling 20 | self.ratio_on_validation = opt.depth_ratio_on_validation 21 | self.val_num_log_images = opt.eval_num_images 22 | 23 | def evaluate(self): 24 | print('Evaluate depth predictions:', flush=True) 25 | 26 | scores, ratios, images = self._run_depth_validation(self.val_num_log_images) 27 | 28 | for domain in scores: 29 | print(f' - Results for domain {domain}:') 30 | 31 | if len(ratios[domain]) > 0: 32 | ratios_np = np.array(ratios[domain]) 33 | if self.ratio_on_validation: 34 | dataset_split_pos = int(len(ratios_np)/4) 35 | else: 36 | dataset_split_pos = int(len(ratios_np)) 37 | ratio_median = np.median(ratios_np[:dataset_split_pos]) 38 | ratio_norm_std = np.std(ratios_np[:dataset_split_pos] / ratio_median) 39 | 40 | print(" Scaling ratios | med: {:0.3f} | std: {:0.3f}".format(ratio_median, ratio_norm_std)) 41 | 42 | metrics = scores[domain].get_scores() 43 | 44 | print("\n " + ("{:>8} | " * 7).format("abs_rel", "sq_rel", "rmse", "rmse_log", "a1", "a2", "a3")) 45 | print(("&{: 8.3f} " * 7).format(metrics['abs_rel'], metrics['sq_rel'], 46 | metrics['rmse'], metrics['rmse_log'], 47 | metrics['delta1'], metrics['delta2'], 48 | metrics['delta3']) + "\\\\") 49 | 50 | for domain in images: 51 | domain_dir = os.path.join(self.log_path, 'eval_images', domain) 52 | os.makedirs(domain_dir, exist_ok=True) 53 | 54 | for i, (color_gt, depth_gt, depth_pred) in enumerate(images[domain]): 55 | image_path = os.path.join(domain_dir, f'img_{i}.png') 56 | 57 | logged_images = ( 58 | color_gt, 59 | colors.depth_norm_image(depth_pred), 60 | colors.depth_norm_image(depth_gt), 61 | ) 62 | 63 | save_image( 64 | torch.cat(logged_images, 2).clamp(0, 1), 65 | image_path 66 | ) 67 | 68 | self._log_gpu_memory() 69 | 70 | 71 | if __name__ == "__main__": 72 | opt = DepthEvaluationArguments().parse() 73 | 74 | if opt.model_load is None: 75 | raise Exception('You must use --model-load to select a model state directory to run evaluation on') 76 | 77 | if opt.sys_best_effort_determinism: 78 | import random 79 | 80 | torch.backends.cudnn.deterministic = True 81 | torch.backends.cudnn.benchmark = False 82 | np.random.seed(1) 83 | torch.manual_seed(1) 84 | torch.cuda.manual_seed(1) 85 | random.seed(1) 86 | 87 | evaluator = DepthEvaluator(opt) 88 | evaluator.evaluate() 89 | -------------------------------------------------------------------------------- /eval_depth.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python3 eval_depth.py\ 4 | --sys-best-effort-determinism \ 5 | --model-name "eval_kitti_depth" \ 6 | --model-load sgdepth_eccv_test/zhou_full/checkpoints/epoch_20 \ 7 | --depth-validation-loaders "kitti_zhou_test" -------------------------------------------------------------------------------- /eval_pose.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Python standard library 4 | import os 5 | 6 | # Public libraries 7 | import numpy as np 8 | import torch 9 | 10 | # Local imports 11 | from arguments import PoseEvaluationArguments 12 | from harness import Harness 13 | 14 | 15 | class PoseEvaluator(Harness): 16 | def _init_validation(self, opt): 17 | self.fixed_depth_scaling = opt.pose_validation_fixed_scaling 18 | self.val_num_log_images = opt.eval_num_images 19 | 20 | def evaluate(self): 21 | print('Evaluate pose predictions:', flush=True) 22 | 23 | scores = self._run_pose_validation() 24 | 25 | for domain in scores: 26 | print(f' - Results for domain {domain}:') 27 | 28 | metrics = scores[domain].get_scores() 29 | 30 | print("\n Trajectory error: {:0.3f}, std: {:0.3f}\n".format(metrics['mean'], metrics['std'])) 31 | 32 | self._log_gpu_memory() 33 | 34 | 35 | if __name__ == "__main__": 36 | opt = PoseEvaluationArguments().parse() 37 | 38 | if opt.model_load is None: 39 | raise Exception('You must use --model-load to select a model state directory to run evaluation on') 40 | 41 | if opt.sys_best_effort_determinism: 42 | import random 43 | 44 | torch.backends.cudnn.deterministic = True 45 | torch.backends.cudnn.benchmark = False 46 | np.random.seed(1) 47 | torch.manual_seed(1) 48 | torch.cuda.manual_seed(1) 49 | random.seed(1) 50 | 51 | evaluator = PoseEvaluator(opt) 52 | evaluator.evaluate() 53 | -------------------------------------------------------------------------------- /eval_pose.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | srun python3 eval_pose.py\ 4 | --sys-best-effort-determinism \ 5 | --model-name "eval_kitti_pose" \ 6 | --model-load sgdepth_eccv_test/zhou_full/checkpoints/epoch_20 \ 7 | --pose-validation-loaders "kitti_odom09_validation" 8 | -------------------------------------------------------------------------------- /eval_segmentation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Python standard library 4 | import os 5 | 6 | # Public libraries 7 | import torch 8 | from torchvision.utils import save_image 9 | 10 | # Local imports 11 | import colors 12 | from arguments import SegmentationEvaluationArguments 13 | from harness import Harness 14 | 15 | 16 | class SegmentationEvaluator(Harness): 17 | def _init_resampler(self, opt): 18 | pass 19 | 20 | def _init_validation(self, opt): 21 | self.val_num_log_images = opt.eval_num_images 22 | self.eval_name = opt.model_name 23 | 24 | def evaluate(self): 25 | print('Evaluate segmentation predictions:', flush=True) 26 | 27 | scores, images = self._run_segmentation_validation( 28 | self.val_num_log_images 29 | ) 30 | 31 | for domain in scores: 32 | print('eval_name | domain | miou | accuracy') 33 | 34 | metrics = scores[domain].get_scores() 35 | 36 | miou = metrics['meaniou'] 37 | acc = metrics['meanacc'] 38 | 39 | print(f'{self.eval_name:12} | {domain:20} | {miou:8.3f} | {acc:8.3f}', flush=True) 40 | 41 | for domain in 
images: 42 | domain_dir = os.path.join(self.log_path, 'eval_images', domain) 43 | os.makedirs(domain_dir, exist_ok=True) 44 | 45 | for i, (color_gt, seg_gt, seg_pred) in enumerate(images[domain]): 46 | image_path = os.path.join(domain_dir, f'img_{i}.png') 47 | 48 | logged_images = ( 49 | color_gt, 50 | colors.seg_idx_image(seg_pred), 51 | colors.seg_idx_image(seg_gt), 52 | ) 53 | 54 | save_image( 55 | torch.cat(logged_images, 2).clamp(0, 1), 56 | image_path 57 | ) 58 | 59 | self._log_gpu_memory() 60 | 61 | return scores 62 | 63 | 64 | if __name__ == "__main__": 65 | opt = SegmentationEvaluationArguments().parse() 66 | 67 | if opt.model_load is None: 68 | raise Exception('You must use --model-load to select a model state directory to run evaluation on') 69 | 70 | if opt.sys_best_effort_determinism: 71 | import random 72 | import numpy as np 73 | 74 | torch.backends.cudnn.deterministic = True 75 | torch.backends.cudnn.benchmark = False 76 | np.random.seed(1) 77 | torch.manual_seed(1) 78 | torch.cuda.manual_seed(1) 79 | random.seed(1) 80 | 81 | evaluator = SegmentationEvaluator(opt) 82 | evaluator.evaluate() 83 | -------------------------------------------------------------------------------- /eval_segmentation.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python3 eval_segmentation.py \ 4 | --sys-best-effort-determinism \ 5 | --model-name "eval_cs" \ 6 | --model-load sgdepth_eccv_test/zhou_full/checkpoints/epoch_20 \ 7 | --segmentation-validation-loaders "cityscapes_validation" 8 | 9 | python3 eval_segmentation.py \ 10 | --sys-best-effort-determinism \ 11 | --model-name "eval_kitti" \ 12 | --model-load sgdepth_eccv_test/zhou_full/checkpoints/epoch_20 \ 13 | --segmentation-validation-loaders "kitti_2015_train" \ 14 | --segmentation-validation-resize-width 640 \ 15 | --segmentation-validation-resize-height 192 -------------------------------------------------------------------------------- /experiments/kitti_full.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python3 ../train.py \ 4 | --experiment-class sgdepth_eccv_test \ 5 | --model-name kitti_full \ 6 | --masking-enable \ 7 | --masking-from-epoch 15 \ 8 | --masking-linear-increase 9 | -------------------------------------------------------------------------------- /experiments/kitti_only_depth.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python3 ../train.py \ 4 | --experiment-class sgdepth_eccv_test \ 5 | --model-name kitti_only_depth \ 6 | --depth-training-batch-size 12 \ 7 | --segmentation-training-loaders "" \ 8 | --train-depth-grad-scale 1.0 \ 9 | --train-segmentation-grad-scale 0.0 10 | -------------------------------------------------------------------------------- /experiments/zhou_full.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python3 ../train.py \ 4 | --experiment-class sgdepth_eccv_test \ 5 | --model-name zhou_full \ 6 | --depth-training-loaders "kitti_zhou_train" \ 7 | --train-batches-per-epoch 7293 \ 8 | --masking-enable \ 9 | --masking-from-epoch 15 \ 10 | --masking-linear-increase 11 | 12 | -------------------------------------------------------------------------------- /experiments/zhou_only_depth.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python3 ../train.py \ 4 | --experiment-class sgdepth_eccv_test \ 5 | 
--model-name zhou_only_depth \ 6 | --depth-training-loaders "kitti_zhou_train" \ 7 | --train-batches-per-epoch 7293 \ 8 | --depth-training-batch-size 12 \ 9 | --segmentation-training-loaders "" \ 10 | --train-depth-grad-scale 1.0 \ 11 | --train-segmentation-grad-scale 0.0 12 | -------------------------------------------------------------------------------- /harness.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Python standard library 4 | import json 5 | import pickle 6 | import os 7 | 8 | # Public libraries 9 | import numpy as np 10 | import torch 11 | import torch.nn.functional as functional 12 | 13 | # IfN libraries 14 | import dataloader.file_io.get_path as get_path 15 | from dataloader.eval.metrics import DepthRunningScore, SegmentationRunningScore, PoseRunningScore 16 | 17 | # Local imports 18 | import loaders, loaders.segmentation, loaders.depth, loaders.pose, loaders.fns 19 | 20 | from state_manager import StateManager 21 | from perspective_resample import PerspectiveResampler 22 | 23 | 24 | class Harness(object): 25 | def __init__(self, opt): 26 | print('Starting initialization', flush=True) 27 | 28 | self._init_device(opt) 29 | self._init_resampler(opt) 30 | self._init_losses(opt) 31 | self._init_log_dir(opt) 32 | self._init_logging(opt) 33 | self._init_tensorboard(opt) 34 | self._init_state(opt) 35 | self._init_train_loaders(opt) 36 | self._init_training(opt) 37 | self._init_validation_loaders(opt) 38 | self._init_validation(opt) 39 | self._save_opts(opt) 40 | 41 | print('Summary:') 42 | print(f' - Model name: {opt.model_name}') 43 | print(f' - Logging directory: {self.log_path}') 44 | print(f' - Using device: {self._pretty_device_name()}') 45 | 46 | def _init_device(self, opt): 47 | cpu = not torch.cuda.is_available() 48 | cpu = cpu or opt.sys_cpu 49 | 50 | self.device = torch.device("cpu" if cpu else "cuda") 51 | 52 | def _init_resampler(self, opt): 53 | if hasattr(opt, 'depth_min_sampling_res'): 54 | self.resample = PerspectiveResampler(opt.model_depth_max, opt.model_depth_min, opt.depth_min_sampling_res) 55 | else: 56 | self.resample = PerspectiveResampler(opt.model_depth_max, opt.model_depth_min) 57 | 58 | def _init_losses(self, opt): 59 | pass 60 | 61 | def _init_log_dir(self, opt): 62 | path_getter = get_path.GetPath() 63 | log_base = path_getter.get_checkpoint_path() 64 | 65 | self.log_path = os.path.join(log_base, opt.experiment_class, opt.model_name) 66 | 67 | os.makedirs(self.log_path, exist_ok=True) 68 | 69 | def _init_logging(self, opt): 70 | pass 71 | 72 | def _init_tensorboard(self, opt): 73 | pass 74 | 75 | def _init_state(self, opt): 76 | self.state = StateManager( 77 | opt.experiment_class, opt.model_name, self.device, opt.model_split_pos, opt.model_num_layers, 78 | opt.train_depth_grad_scale, opt.train_segmentation_grad_scale, 79 | opt.train_weights_init, opt.model_depth_resolutions, opt.model_num_layers_pose, 80 | opt.train_learning_rate, opt.train_weight_decay, opt.train_scheduler_step_size 81 | ) 82 | if opt.model_load is not None: 83 | self.state.load(opt.model_load, opt.model_disable_lr_loading) 84 | 85 | def _init_train_loaders(self, opt): 86 | pass 87 | 88 | def _init_training(self, opt): 89 | pass 90 | 91 | def _init_validation_loaders(self, opt): 92 | print('Loading validation dataset metadata:', flush=True) 93 | 94 | if hasattr(opt, 'depth_validation_loaders'): 95 | self.depth_validation_loader = loaders.ChainedLoaderList( 96 | getattr(loaders.depth, loader_name)( 97 | 
img_height=opt.depth_validation_resize_height, 98 | img_width=opt.depth_validation_resize_width, 99 | batch_size=opt.depth_validation_batch_size, 100 | num_workers=opt.sys_num_workers 101 | ) 102 | for loader_name in opt.depth_validation_loaders.split(',') if (loader_name != '') 103 | ) 104 | 105 | if hasattr(opt, 'pose_validation_loaders'): 106 | self.pose_validation_loader = loaders.ChainedLoaderList( 107 | getattr(loaders.pose, loader_name)( 108 | img_height=opt.pose_validation_resize_height, 109 | img_width=opt.pose_validation_resize_width, 110 | batch_size=opt.pose_validation_batch_size, 111 | num_workers=opt.sys_num_workers 112 | ) 113 | for loader_name in opt.pose_validation_loaders.split(',') if (loader_name != '') 114 | ) 115 | 116 | if hasattr(opt, 'segmentation_validation_loaders'): 117 | self.segmentation_validation_loader = loaders.ChainedLoaderList( 118 | getattr(loaders.segmentation, loader_name)( 119 | resize_height=opt.segmentation_validation_resize_height, 120 | resize_width=opt.segmentation_validation_resize_width, 121 | batch_size=opt.segmentation_validation_batch_size, 122 | num_workers=opt.sys_num_workers 123 | ) 124 | for loader_name in opt.segmentation_validation_loaders.split(',') if (loader_name != '') 125 | ) 126 | 127 | def _init_validation(self, opt): 128 | self.fixed_depth_scaling = opt.depth_validation_fixed_scaling 129 | 130 | def _pretty_device_name(self): 131 | dev_type = self.device.type 132 | 133 | dev_idx = ( 134 | f',{self.device.index}' 135 | if (self.device.index is not None) 136 | else '' 137 | ) 138 | 139 | dev_cname = ( 140 | f' ({torch.cuda.get_device_name(self.device)})' 141 | if (dev_type == 'cuda') 142 | else '' 143 | ) 144 | 145 | return f'{dev_type}{dev_idx}{dev_cname}' 146 | 147 | def _log_gpu_memory(self): 148 | if self.device.type == 'cuda': 149 | max_mem = torch.cuda.max_memory_allocated(self.device) 150 | 151 | print('Maximum MB of GPU memory used:') 152 | print(str(max_mem/(1024**2))) 153 | 154 | def _save_opts(self, opt): 155 | opt_path = os.path.join(self.log_path, 'opt.json') 156 | 157 | with open(opt_path, 'w') as fd: 158 | json.dump(vars(opt), fd, indent=2) 159 | 160 | def _batch_to_device(self, batch_cpu): 161 | batch_gpu = list() 162 | 163 | for dataset_cpu in batch_cpu: 164 | dataset_gpu = dict() 165 | 166 | for k, ipt in dataset_cpu.items(): 167 | if isinstance(ipt, torch.Tensor): 168 | dataset_gpu[k] = ipt.to(self.device) 169 | 170 | else: 171 | dataset_gpu[k] = ipt 172 | 173 | batch_gpu.append(dataset_gpu) 174 | 175 | return tuple(batch_gpu) 176 | 177 | def _validate_batch_depth(self, model, batch, score, ratios, images): 178 | if len(batch) != 1: 179 | raise Exception('Can only run validation on batches containing only one dataset') 180 | 181 | im_scores = list() 182 | 183 | batch_gpu = self._batch_to_device(batch) 184 | outputs = model(batch_gpu) 185 | 186 | colors_gt = batch[0]['color', 0, -1] 187 | depths_gt = batch[0]['depth', 0, 0][:, 0] 188 | 189 | disps_pred = outputs[0]["disp", 0] 190 | disps_scaled_pred = self.resample.scale_disp(disps_pred) 191 | disps_scaled_pred = disps_scaled_pred.cpu()[:, 0] 192 | 193 | # Process each image from the batch separately 194 | for i in range(depths_gt.shape[0]): 195 | # If you are here due to an exception, make sure that your loader uses 196 | # AddKeyValue('domain', domain_name), AddKeyValue('validation_mask', mask_fn) 197 | # and AddKeyValue('validation_clamp', clamp_fn) to add these keys to each input sample. 
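            # (For reference, the bundled loaders set these via transforms, e.g.
            # loaders/depth/validation.py uses
            #     tf.AddKeyValue('domain', 'kitti_zhou_val_depth'),
            #     tf.AddKeyValue('validation_mask', 'validation_mask_kitti_zhou'),
            #     tf.AddKeyValue('validation_clamp', 'validation_clamp_kitti'),
            # and the mask/clamp names are resolved here via loaders.fns.get().)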
198 | # There is no sensible default, that works for all datasets, 199 | # so you have have to define one on a per-set basis. 200 | domain = batch[0]['domain'][i] 201 | mask_fn = loaders.fns.get(batch[0]['validation_mask'][i]) 202 | clamp_fn = loaders.fns.get(batch[0]['validation_clamp'][i]) 203 | 204 | color_gt = colors_gt[i].unsqueeze(0) 205 | depth_gt = depths_gt[i].unsqueeze(0) 206 | disp_scaled_pred = disps_scaled_pred[i].unsqueeze(0) 207 | 208 | img_height = depth_gt.shape[1] 209 | img_width = depth_gt.shape[2] 210 | disp_scaled_pred = functional.interpolate( 211 | disp_scaled_pred.unsqueeze(1), 212 | (img_height, img_width), 213 | align_corners=False, 214 | mode='bilinear' 215 | ).squeeze(1) 216 | depth_pred = 1 / disp_scaled_pred 217 | 218 | images.append((color_gt, depth_gt, depth_pred)) 219 | 220 | # Datasets/splits define their own masking rules 221 | # delegate masking to functions defined in the loader 222 | mask = mask_fn(depth_gt) 223 | depth_pred = depth_pred[mask] 224 | depth_gt = depth_gt[mask] 225 | 226 | if self.fixed_depth_scaling != 0: 227 | ratio = self.fixed_depth_scaling 228 | 229 | else: 230 | median_gt = np.median(depth_gt.numpy()) 231 | median_pred = np.median(depth_pred.numpy()) 232 | 233 | ratio = (median_gt / median_pred).item() 234 | 235 | ratios.append(ratio) 236 | depth_pred *= ratio 237 | 238 | # Datasets/splits define their own prediction clamping rules 239 | # delegate clamping to functions defined in the loader 240 | depth_pred = clamp_fn(depth_pred) 241 | 242 | score.update( 243 | depth_gt.numpy(), 244 | depth_pred.numpy() 245 | ) 246 | 247 | return im_scores 248 | 249 | def _validate_batch_segmentation(self, model, batch, score, images): 250 | if len(batch) != 1: 251 | raise Exception('Can only run validation on batches containing only one dataset') 252 | 253 | im_scores = list() 254 | 255 | batch_gpu = self._batch_to_device(batch) 256 | outputs = model(batch_gpu) # forward the data through the network 257 | 258 | colors_gt = batch[0]['color', 0, -1] 259 | segs_gt = batch[0]['segmentation', 0, 0].squeeze(1).long() # shape [1,1024,2048] 260 | segs_pred = outputs[0]['segmentation_logits', 0] # shape [1,20,192,640] one for every class 261 | segs_pred = functional.interpolate(segs_pred, segs_gt[0, :, :].shape, mode='nearest') # upscale predictions 262 | 263 | for i in range(segs_pred.shape[0]): 264 | color_gt = colors_gt[i].unsqueeze(0) 265 | seg_gt = segs_gt[i].unsqueeze(0) 266 | seg_pred = segs_pred[i].unsqueeze(0) 267 | 268 | images.append((color_gt, seg_gt, seg_pred.argmax(1).cpu())) 269 | 270 | seg_pred = seg_pred.exp().cpu() # exp preds and shift to CPU 271 | seg_pred = seg_pred.numpy() # transform preds to np array 272 | seg_pred = seg_pred.argmax(1) # get the highest score for classes per pixel 273 | seg_gt = seg_gt.numpy() # transform gt to np array 274 | 275 | score.update(seg_gt, seg_pred) 276 | 277 | return im_scores 278 | 279 | def _validate_batch_pose(self, model, batch, score): 280 | if len(batch) != 1: 281 | raise Exception('Can only run validation on batches containing only one dataset') 282 | 283 | batch_gpu = self._batch_to_device(batch) 284 | outputs = model(batch_gpu) 285 | poses_pred = outputs[0][("cam_T_cam", 0, 1)] 286 | poses_gt = batch[0][('poses', 0, -1)] 287 | 288 | for i in range(poses_pred.shape[0]): 289 | pose_gt = poses_gt[i].unsqueeze(0).cpu().numpy() 290 | pose_pred = poses_pred[i].squeeze(0).cpu().numpy() 291 | score.update(pose_gt, pose_pred) 292 | 293 | def _validate_batch_joint(self, model, batch, depth_score, 
depth_ratios, depth_images, 294 | seg_score, seg_images, seg_perturbations, 295 | seg_im_scores, depth_im_scores): 296 | 297 | # apply a perturbation onto the input image 298 | loss_fn = torch.nn.CrossEntropyLoss(reduction="none", ignore_index=255) 299 | batch, seg_perturbation = self.attack_model.perturb(batch, model, loss_fn) 300 | seg_perturbations.append(seg_perturbation) 301 | 302 | # pass the evaluation to the single evaluation routines 303 | with torch.no_grad(): 304 | seg_im_scores.extend( 305 | self._validate_batch_segmentation(model, batch, seg_score, seg_images) 306 | ) 307 | depth_im_scores.extend( 308 | self._validate_batch_depth(model, batch, depth_score, depth_ratios, depth_images) 309 | ) 310 | 311 | def _run_depth_validation(self, images_to_keep=0): 312 | scores = dict() 313 | ratios = dict() 314 | images = dict() 315 | 316 | with torch.no_grad(), self.state.model_manager.get_eval() as model: 317 | for batch in self.depth_validation_loader: 318 | domain = batch[0]['domain'][0] 319 | 320 | if domain not in scores: 321 | scores[domain] = DepthRunningScore() 322 | ratios[domain] = list() 323 | images[domain] = list() 324 | 325 | _ = self._validate_batch_depth(model, batch, scores[domain], ratios[domain], images[domain]) 326 | 327 | images[domain] = images[domain][:images_to_keep] 328 | 329 | return scores, ratios, images 330 | 331 | def _run_pose_validation(self): 332 | scores = dict() 333 | 334 | with torch.no_grad(), self.state.model_manager.get_eval() as model: 335 | for batch in self.pose_validation_loader: 336 | 337 | domain = batch[0]['domain'][0] 338 | 339 | if domain not in scores: 340 | scores[domain] = PoseRunningScore() 341 | 342 | self._validate_batch_pose(model, batch, scores[domain]) 343 | 344 | return scores 345 | 346 | def _run_segmentation_validation(self, images_to_keep=0): 347 | scores = dict() 348 | images = dict() 349 | 350 | # torch.no_grad() = disable gradient calculation 351 | with torch.no_grad(), self.state.model_manager.get_eval() as model: 352 | for batch in self.segmentation_validation_loader: 353 | domain = batch[0]['domain'][0] 354 | num_classes = batch[0]['num_classes'][0].item() 355 | 356 | if domain not in scores: 357 | scores[domain] = SegmentationRunningScore(num_classes) 358 | images[domain] = list() 359 | 360 | _ = self._validate_batch_segmentation(model, batch, scores[domain], images[domain]) 361 | 362 | images[domain] = images[domain][:images_to_keep] 363 | 364 | return scores, images 365 | -------------------------------------------------------------------------------- /imgs/ECCV_presentation.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ifnspaml/SGDepth/7d594b5153801d8240ec809ef6065eff1cd1f4fc/imgs/ECCV_presentation.jpg -------------------------------------------------------------------------------- /imgs/intro.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ifnspaml/SGDepth/7d594b5153801d8240ec809ef6065eff1cd1f4fc/imgs/intro.png -------------------------------------------------------------------------------- /imgs/qualitative.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ifnspaml/SGDepth/7d594b5153801d8240ec809ef6065eff1cd1f4fc/imgs/qualitative.png -------------------------------------------------------------------------------- /imgs/sg_depth.gif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ifnspaml/SGDepth/7d594b5153801d8240ec809ef6065eff1cd1f4fc/imgs/sg_depth.gif -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import time 2 | from models.sgdepth import SGDepth 3 | import torch 4 | from arguments import InferenceEvaluationArguments 5 | import cv2 6 | import os 7 | import torchvision.transforms as transforms 8 | from PIL import Image 9 | import numpy as np 10 | import glob as glob 11 | 12 | 13 | DEBUG = False # if this flag is set the images are displayed before being saved 14 | 15 | 16 | class Inference: 17 | """Inference without harness or dataloader""" 18 | 19 | def __init__(self): 20 | self.model_path = opt.model_path 21 | self.image_dir = opt.image_path 22 | self.image_path = opt.image_path 23 | self.num_classes = 20 24 | self.depth_min = opt.model_depth_min 25 | self.depth_max = opt.model_depth_max 26 | self.output_path = opt.output_path 27 | self.output_format = opt.output_format 28 | self.all_time = [] 29 | # try: 30 | # self.checkpoint_path = os.environ['IFN_DIR_CHECKPOINT'] 31 | # except KeyError: 32 | # print('No IFN_DIR_CHECKPOINT defined.') 33 | 34 | self.labels = (('CLS_ROAD', (128, 64, 128)), 35 | ('CLS_SIDEWALK', (244, 35, 232)), 36 | ('CLS_BUILDING', (70, 70, 70)), 37 | ('CLS_WALL', (102, 102, 156)), 38 | ('CLS_FENCE', (190, 153, 153)), 39 | ('CLS_POLE', (153, 153, 153)), 40 | ('CLS_TRLIGHT', (250, 170, 30)), 41 | ('CLS_TRSIGN', (220, 220, 0)), 42 | ('CLS_VEGT', (107, 142, 35)), 43 | ('CLS_TERR', (152, 251, 152)), 44 | ('CLS_SKY', (70, 130, 180)), 45 | ('CLS_PERSON', (220, 20, 60)), 46 | ('CLS_RIDER', (255, 0, 0)), 47 | ('CLS_CAR', (0, 0, 142)), 48 | ('CLS_TRUCK', (0, 0, 70)), 49 | ('CLS_BUS', (0, 60, 100)), 50 | ('CLS_TRAIN', (0, 80, 100)), 51 | ('CLS_MCYCLE', (0, 0, 230)), 52 | ('CLS_BCYCLE', (119, 11, 32)), 53 | ) 54 | 55 | def init_model(self): 56 | print("Init Model...") 57 | sgdepth = SGDepth 58 | 59 | with torch.no_grad(): 60 | # init 'empty' model 61 | self.model = sgdepth( 62 | opt.model_split_pos, opt.model_num_layers, opt.train_depth_grad_scale, 63 | opt.train_segmentation_grad_scale, 64 | # opt.train_domain_grad_scale, 65 | opt.train_weights_init, opt.model_depth_resolutions, opt.model_num_layers_pose, 66 | # opt.model_num_domains, 67 | # opt.train_loss_weighting_strategy, 68 | # opt.train_grad_scale_weighting_strategy, 69 | # opt.train_gradnorm_alpha, 70 | # opt.train_uncertainty_eta_depth, 71 | # opt.train_uncertainty_eta_seg, 72 | # opt.model_shared_encoder_batchnorm_momentum 73 | ) 74 | 75 | # load weights (copied from state manager) 76 | state = self.model.state_dict() 77 | to_load = torch.load(self.model_path) 78 | for (k, v) in to_load.items(): 79 | if k not in state: 80 | print(f" - WARNING: Model file contains unknown key {k} ({list(v.shape)})") 81 | 82 | for (k, v) in state.items(): 83 | if k not in to_load: 84 | print(f" - WARNING: Model file does not contain key {k} ({list(v.shape)})") 85 | 86 | else: 87 | state[k] = to_load[k] 88 | 89 | self.model.load_state_dict(state) 90 | self.model = self.model.eval().cuda() # for inference model should be in eval mode and on gpu 91 | 92 | def load_image(self): 93 | print("Load Image: " + self.image_path) 94 | 95 | self.image = Image.open(self.image_path) # open PIL image 96 | self.image_o_width, self.image_o_height = self.image.size 97 | 98 | resize = 
transforms.Resize( 99 | (opt.inference_resize_height, opt.inference_resize_width)) 100 | image = resize(self.image) # resize to argument size 101 | 102 | #center_crop = transforms.CenterCrop((opt.inference_crop_height, opt.inference_crop_width)) 103 | #image = center_crop(image) # crop to input size 104 | 105 | to_tensor = transforms.ToTensor() # transform to tensor 106 | 107 | self.input_image = to_tensor(image) # save tensor image to self.input_image for saving later 108 | image = self.normalize(self.input_image) 109 | 110 | image = image.unsqueeze(0).float().cuda() 111 | 112 | # simulate structure of batch: 113 | image_dict = {('color_aug', 0, 0): image} # dict 114 | image_dict[('color', 0, 0)] = image 115 | image_dict['domain'] = ['cityscapes_val_seg', ] 116 | image_dict['purposes'] = [['segmentation', ], ['depth', ]] 117 | image_dict['num_classes'] = torch.tensor([self.num_classes]) 118 | image_dict['domain_idx'] = torch.tensor(0) 119 | self.batch = (image_dict,) # batch tuple 120 | 121 | 122 | def normalize(self, tensor): 123 | mean = (0.485, 0.456, 0.406) 124 | std = (0.229, 0.224, 0.225) 125 | 126 | normalize = transforms.Normalize(mean, std) 127 | tensor = normalize(tensor) 128 | 129 | return tensor 130 | 131 | 132 | def inference(self): 133 | self.init_model() 134 | print('Saving images to' + str(self.output_path) + ' in ' + str(self.output_format) + '\n \n') 135 | 136 | for image_path in glob.glob(self.image_dir + '/*'): 137 | self.image_path = image_path # for output 138 | 139 | # load image and transform it in necessary batch format 140 | self.load_image() 141 | 142 | start = time.time() 143 | with torch.no_grad(): 144 | output = self.model(self.batch) # forward pictures 145 | 146 | self.all_time.append(time.time() - start) 147 | start = 0 148 | 149 | disps_pred = output[0]["disp", 0] # depth results 150 | segs_pred = output[0]['segmentation_logits', 0] # seg results 151 | 152 | segs_pred = segs_pred.exp().cpu() 153 | segs_pred = segs_pred.numpy() # transform preds to np array 154 | segs_pred = segs_pred.argmax(1) # get the highest score for classes per pixel 155 | 156 | self.save_pred_to_disk(segs_pred, disps_pred) # saves results 157 | 158 | print("Done with all pictures in: " + str(self.output_path)) 159 | print("\nAverage forward time for processing one Image (the first one excluded): ", np.average(self.all_time[1::])) 160 | 161 | def save_pred_to_disk(self, segs_pred, depth_pred): 162 | ## Segmentation visualization 163 | segs_pred = segs_pred[0] 164 | o_size = segs_pred.shape 165 | 166 | # init of seg image 167 | seg_img_array = np.zeros((3, segs_pred.shape[0], segs_pred.shape[1])) 168 | 169 | # create a color image from the classes for every pixel todo: probably a lot faster if vectorized with numpy 170 | i = 0 171 | while i < segs_pred.shape[0]: # for row 172 | n = 0 173 | while n < segs_pred.shape[1]: # for column 174 | lab = 0 175 | while lab < self.num_classes: # for classes 176 | if segs_pred[i, n] == lab: 177 | # write colors to pixel 178 | seg_img_array[0, i, n] = self.labels[lab][1][0] 179 | seg_img_array[1, i, n] = self.labels[lab][1][1] 180 | seg_img_array[2, i, n] = self.labels[lab][1][2] 181 | break 182 | lab += 1 183 | n += 1 184 | i += 1 185 | 186 | # scale the color values to 0-1 for proper visualization of OpenCV 187 | seg_img = seg_img_array.transpose(1, 2, 0).astype(np.uint8) 188 | seg_img = seg_img[:, :, ::-1 ] 189 | 190 | if DEBUG: 191 | cv2.imshow('segmentation', seg_img) 192 | cv2.waitKey() 193 | 194 | # Depth Visualization 195 | depth_pred = 
np.array(depth_pred[0][0].cpu()) # depth predictions to numpy and CPU 196 | 197 | depth_pred = self.scale_depth(depth_pred) # Depthmap in meters 198 | depth_pred = depth_pred * (255 / depth_pred.max()) # Normalize Depth to 255 = max depth 199 | depth_pred = np.clip(depth_pred, 0, 255) # Clip to 255 for safety 200 | depth_pred = depth_pred.astype(np.uint8) # Cast to uint8 for openCV to display 201 | 202 | depth_img = cv2.applyColorMap(depth_pred, cv2.COLORMAP_PLASMA) # Use PLASMA Colormap like in the Paper 203 | 204 | if DEBUG: 205 | cv2.imshow('depth', depth_img) 206 | cv2.waitKey() 207 | 208 | 209 | # Color_img 210 | color_img = np.array(self.image) 211 | # color_img = color_img.transpose((1, 2, 0)) 212 | color_img = color_img[: ,: , ::-1] 213 | 214 | if DEBUG: 215 | cv2.imshow('color', color_img) 216 | cv2.waitKey() 217 | 218 | # resize depth and seg 219 | depth_img = cv2.resize(depth_img, (self.image_o_width, self.image_o_height)) 220 | seg_img = cv2.resize(seg_img, (self.image_o_width, self.image_o_height), interpolation=cv2.INTER_NEAREST) 221 | 222 | # Concetenate all 3 pictures together 223 | conc_img = np.concatenate((color_img, seg_img, depth_img), axis=0) 224 | 225 | if DEBUG: 226 | cv2.imshow('conc', conc_img) 227 | cv2.waitKey() 228 | 229 | img_head, img_tail = os.path.split(self.image_path) 230 | img_name = img_tail.split('.')[0] 231 | print('Saving...') 232 | cv2.imwrite(str(self.output_path +'/' + img_name + self.output_format), conc_img) 233 | 234 | 235 | def scale_depth(self, disp): 236 | min_disp = 1 / self.depth_max 237 | max_disp = 1 / self.depth_min 238 | return min_disp + (max_disp - min_disp) * disp 239 | 240 | 241 | if __name__ == "__main__": 242 | opt = InferenceEvaluationArguments().parse() 243 | 244 | infer = Inference() 245 | infer.inference() 246 | -------------------------------------------------------------------------------- /loaders/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import itertools as it 3 | 4 | class LoaderList(object): 5 | def __init__(self, loaders): 6 | self.loaders = tuple(loaders) 7 | 8 | def __iter__(self): 9 | raise NotImplementedError() 10 | 11 | class FixedLengthLoaderList(LoaderList): 12 | def __init__(self, loaders, length): 13 | super().__init__(loaders) 14 | 15 | self.length = length 16 | 17 | def __iter__(self): 18 | infinite_iters = tuple( 19 | self._make_infinite(domain_idx, loader) 20 | for domain_idx, loader in enumerate(self.loaders) 21 | ) 22 | 23 | length_iter = range(self.length) 24 | 25 | for batch_idx, *group in zip(length_iter, *infinite_iters): 26 | yield tuple(group) 27 | 28 | def _make_infinite(self, domain_idx, loader): 29 | while True: 30 | for batch in loader: 31 | batch['domain_idx'] = torch.tensor(domain_idx) 32 | 33 | yield batch 34 | 35 | 36 | class ChainedLoaderList(LoaderList): 37 | def __iter__(self): 38 | for domain_idx, loader in enumerate(self.loaders): 39 | for batch in loader: 40 | batch['domain_idx'] = torch.tensor(domain_idx) 41 | 42 | yield (batch, ) 43 | -------------------------------------------------------------------------------- /loaders/depth/__init__.py: -------------------------------------------------------------------------------- 1 | from .train import * 2 | from .validation import * 3 | -------------------------------------------------------------------------------- /loaders/depth/train.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader, 
ConcatDataset 2 | 3 | from dataloader.pt_data_loader.specialdatasets import StandardDataset 4 | import dataloader.pt_data_loader.mytransforms as tf 5 | 6 | 7 | def kitti_zhou_train(resize_height, resize_width, crop_height, crop_width, batch_size, num_workers): 8 | """A loader that loads image sequences for depth training from the 9 | kitti training set. 10 | This loader returns sequences from the left camera, as well as from the right camera. 11 | """ 12 | 13 | transforms_common = [ 14 | tf.RandomHorizontalFlip(), 15 | tf.CreateScaledImage(), 16 | tf.Resize( 17 | (resize_height, resize_width), 18 | image_types=('color', 'depth', 'camera_intrinsics', 'K') 19 | ), 20 | tf.ConvertDepth(), 21 | tf.CreateColoraug(new_element=True), 22 | tf.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, gamma=0.0, fraction=0.5), 23 | tf.RemoveOriginals(), 24 | tf.ToTensor(), 25 | tf.NormalizeZeroMean(), 26 | tf.AddKeyValue('domain', 'kitti_zhou_train_depth'), 27 | tf.AddKeyValue('purposes', ('depth', 'domain')), 28 | ] 29 | 30 | dataset_name = 'kitti' 31 | 32 | cfg_common = { 33 | 'dataset': dataset_name, 34 | 'trainvaltest_split': 'train', 35 | 'video_mode': 'video', 36 | 'stereo_mode': 'mono', 37 | 'split': 'zhou_split', 38 | 'video_frames': (0, -1, 1), 39 | 'disable_const_items': False 40 | } 41 | 42 | cfg_left = {'keys_to_load': ('color', ), 43 | 'keys_to_video': ('color', )} 44 | 45 | cfg_right = {'keys_to_load': ('color_right',), 46 | 'keys_to_video': ('color_right',)} 47 | 48 | dataset_left = StandardDataset( 49 | data_transforms=transforms_common, 50 | **cfg_left, 51 | **cfg_common 52 | ) 53 | 54 | dataset_right = StandardDataset( 55 | data_transforms=[tf.ExchangeStereo()] + transforms_common, 56 | **cfg_right, 57 | **cfg_common 58 | ) 59 | 60 | dataset = ConcatDataset((dataset_left, dataset_right)) 61 | 62 | loader = DataLoader( 63 | dataset, batch_size, True, 64 | num_workers=num_workers, pin_memory=True, drop_last=True 65 | ) 66 | 67 | print(f" - Can use {len(dataset)} images from the kitti (zhou_split) train split for depth training", flush=True) 68 | 69 | return loader 70 | 71 | 72 | def kitti_kitti_train(resize_height, resize_width, crop_height, crop_width, batch_size, num_workers): 73 | """A loader that loads image sequences for depth training from the kitti training set. 74 | This loader returns sequences from the left camera, as well as from the right camera. 
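    A minimal usage sketch (the image sizes and batch size below are
    illustrative values, not project defaults):

    >>> loader = kitti_kitti_train(
    ...     resize_height=192, resize_width=640,
    ...     crop_height=192, crop_width=640,
    ...     batch_size=12, num_workers=4)
    >>> batch = next(iter(loader))  # dict with keys like ('color_aug', frame_id, scale)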
75 | """ 76 | 77 | transforms_common = [ 78 | tf.RandomHorizontalFlip(), 79 | tf.CreateScaledImage(), 80 | tf.Resize( 81 | (resize_height, resize_width), 82 | image_types=('color', 'depth', 'camera_intrinsics', 'K') 83 | ), 84 | tf.ConvertDepth(), 85 | tf.CreateColoraug(new_element=True), 86 | tf.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, gamma=0.0, fraction=0.5), 87 | tf.RemoveOriginals(), 88 | tf.ToTensor(), 89 | tf.NormalizeZeroMean(), 90 | tf.AddKeyValue('domain', 'kitti_kitti_train_depth'), 91 | tf.AddKeyValue('purposes', ('depth', 'domain')), 92 | ] 93 | 94 | dataset_name = 'kitti' 95 | 96 | cfg_common = { 97 | 'dataset': dataset_name, 98 | 'trainvaltest_split': 'train', 99 | 'video_mode': 'video', 100 | 'stereo_mode': 'mono', 101 | 'split': 'kitti_split', 102 | 'video_frames': (0, -1, 1), 103 | 'disable_const_items': False 104 | } 105 | 106 | cfg_left = {'keys_to_load': ('color',), 107 | 'keys_to_video': ('color',)} 108 | 109 | cfg_right = {'keys_to_load': ('color_right',), 110 | 'keys_to_video': ('color_right',)} 111 | 112 | dataset_left = StandardDataset( 113 | data_transforms=transforms_common, 114 | **cfg_left, 115 | **cfg_common 116 | ) 117 | 118 | dataset_right = StandardDataset( 119 | data_transforms=[tf.ExchangeStereo()] + transforms_common, 120 | **cfg_right, 121 | **cfg_common 122 | ) 123 | 124 | dataset = ConcatDataset((dataset_left, dataset_right)) 125 | 126 | loader = DataLoader( 127 | dataset, batch_size, True, 128 | num_workers=num_workers, pin_memory=True, drop_last=True 129 | ) 130 | 131 | print(f" - Can use {len(dataset)} images from the kitti (kitti_split) train set for depth training", flush=True) 132 | 133 | return loader 134 | 135 | 136 | def kitti_odom09_train(resize_height, resize_width, crop_height, crop_width, batch_size, num_workers): 137 | """A loader that loads image sequences for depth training from the 138 | kitti training set. 139 | This loader returns sequences from the left camera, as well as from the right camera. 
140 | """ 141 | 142 | transforms_common = [ 143 | tf.RandomHorizontalFlip(), 144 | tf.CreateScaledImage(), 145 | tf.Resize( 146 | (resize_height, resize_width), 147 | image_types=('color', 'depth', 'camera_intrinsics', 'K') 148 | ), 149 | tf.ConvertDepth(), 150 | tf.CreateColoraug(new_element=True), 151 | tf.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, gamma=0.0, fraction=0.5), 152 | tf.RemoveOriginals(), 153 | tf.ToTensor(), 154 | tf.NormalizeZeroMean(), 155 | tf.AddKeyValue('domain', 'kitti_odom09_train_depth'), 156 | tf.AddKeyValue('purposes', ('depth', 'domain')), 157 | ] 158 | 159 | dataset_name = 'kitti' 160 | 161 | cfg_common = { 162 | 'dataset': dataset_name, 163 | 'trainvaltest_split': 'train', 164 | 'video_mode': 'video', 165 | 'stereo_mode': 'stereo', 166 | 'split': 'odom09_split', 167 | 'video_frames': (0, -1, 1), 168 | 'disable_const_items': False 169 | } 170 | 171 | cfg_left = {'keys_to_load': ('color', ), 172 | 'keys_to_video': ('color', )} 173 | 174 | cfg_right = {'keys_to_load': ('color_right',), 175 | 'keys_to_video': ('color_right',)} 176 | 177 | dataset_left = StandardDataset( 178 | data_transforms=transforms_common, 179 | **cfg_left, 180 | **cfg_common 181 | ) 182 | 183 | dataset_right = StandardDataset( 184 | data_transforms=[tf.ExchangeStereo()] + transforms_common, 185 | **cfg_right, 186 | **cfg_common 187 | ) 188 | 189 | dataset = ConcatDataset((dataset_left, dataset_right)) 190 | 191 | loader = DataLoader( 192 | dataset, batch_size, True, 193 | num_workers=num_workers, pin_memory=True, drop_last=True 194 | ) 195 | 196 | print(f" - Can use {len(dataset)} images from the kitti (odom09_split) train split for depth training", flush=True) 197 | 198 | return loader 199 | 200 | 201 | def kitti_benchmark_train(resize_height, resize_width, crop_height, crop_width, batch_size, num_workers): 202 | """A loader that loads image sequences for depth training from the 203 | kitti training set. 204 | This loader returns sequences from the left camera, as well as from the right camera. 
205 | """ 206 | 207 | transforms_common = [ 208 | tf.RandomHorizontalFlip(), 209 | tf.CreateScaledImage(), 210 | tf.Resize( 211 | (resize_height, resize_width), 212 | image_types=('color', 'depth', 'camera_intrinsics', 'K') 213 | ), 214 | tf.ConvertDepth(), 215 | tf.CreateColoraug(new_element=True), 216 | tf.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, gamma=0.0, fraction=0.5), 217 | tf.RemoveOriginals(), 218 | tf.ToTensor(), 219 | tf.NormalizeZeroMean(), 220 | tf.AddKeyValue('domain', 'kitti_benchmark_train_depth'), 221 | tf.AddKeyValue('purposes', ('depth', 'domain')), 222 | ] 223 | 224 | dataset_name = 'kitti' 225 | 226 | cfg_common = { 227 | 'dataset': dataset_name, 228 | 'trainvaltest_split': 'train', 229 | 'video_mode': 'video', 230 | 'stereo_mode': 'stereo', 231 | 'split': 'benchmark_split', 232 | 'video_frames': (0, -1, 1), 233 | 'disable_const_items': False 234 | } 235 | 236 | cfg_left = {'keys_to_load': ('color', ), 237 | 'keys_to_video': ('color', )} 238 | 239 | cfg_right = {'keys_to_load': ('color_right',), 240 | 'keys_to_video': ('color_right',)} 241 | 242 | dataset_left = StandardDataset( 243 | data_transforms=transforms_common, 244 | **cfg_left, 245 | **cfg_common 246 | ) 247 | 248 | dataset_right = StandardDataset( 249 | data_transforms=[tf.ExchangeStereo()] + transforms_common, 250 | **cfg_right, 251 | **cfg_common 252 | ) 253 | 254 | dataset = ConcatDataset((dataset_left, dataset_right)) 255 | 256 | loader = DataLoader( 257 | dataset, batch_size, True, 258 | num_workers=num_workers, pin_memory=True, drop_last=True 259 | ) 260 | 261 | print(f" - Can use {len(dataset)} images from the kitti (benchmark_split) train split for depth training", 262 | flush=True) 263 | 264 | return loader 265 | -------------------------------------------------------------------------------- /loaders/depth/validation.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | 3 | from dataloader.pt_data_loader.specialdatasets import StandardDataset 4 | import dataloader.pt_data_loader.mytransforms as tf 5 | 6 | 7 | def kitti_zhou_validation(img_height, img_width, batch_size, num_workers): 8 | """A loader that loads images and depth ground truth for 9 | depth validation from the kitti validation set. 
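    The harness resolves validation loaders by function name, i.e. via
    getattr(loaders.depth, 'kitti_zhou_validation'), and calls them with the
    configured validation resize and batch-size options.  Usage sketch
    (sizes are illustrative):

    >>> loader = kitti_zhou_validation(img_height=192, img_width=640,
    ...                                batch_size=1, num_workers=4)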
10 | """ 11 | 12 | transforms = [ 13 | tf.CreateScaledImage(True), 14 | tf.Resize( 15 | (img_height, img_width), 16 | image_types=('color', ) 17 | ), 18 | tf.ConvertDepth(), 19 | tf.CreateColoraug(), 20 | tf.ToTensor(), 21 | tf.NormalizeZeroMean(), 22 | tf.AddKeyValue('domain', 'kitti_zhou_val_depth'), 23 | tf.AddKeyValue('validation_mask', 'validation_mask_kitti_zhou'), 24 | tf.AddKeyValue('validation_clamp', 'validation_clamp_kitti'), 25 | tf.AddKeyValue('purposes', ('depth', )), 26 | ] 27 | 28 | dataset = StandardDataset( 29 | dataset='kitti', 30 | split='zhou_split', 31 | trainvaltest_split='validation', 32 | video_mode='mono', 33 | stereo_mode='mono', 34 | keys_to_load=('color', 'depth'), 35 | data_transforms=transforms, 36 | video_frames=(0, ), 37 | disable_const_items=True 38 | ) 39 | 40 | loader = DataLoader( 41 | dataset, batch_size, False, 42 | num_workers=num_workers, pin_memory=True, drop_last=False 43 | ) 44 | 45 | print(f" - Can use {len(dataset)} images from the kitti (zhou_split) validation set for depth validation", 46 | flush=True) 47 | 48 | return loader 49 | 50 | 51 | def kitti_zhou_test(img_height, img_width, batch_size, num_workers): 52 | """A loader that loads images and depth ground truth for 53 | depth evaluation from the kitti test set. 54 | """ 55 | 56 | transforms = [ 57 | tf.CreateScaledImage(True), 58 | tf.Resize( 59 | (img_height, img_width), 60 | image_types=('color', ) 61 | ), 62 | tf.ConvertDepth(), 63 | tf.CreateColoraug(), 64 | tf.ToTensor(), 65 | tf.NormalizeZeroMean(), 66 | tf.AddKeyValue('domain', 'kitti_zhou_test_depth'), 67 | tf.AddKeyValue('validation_mask', 'validation_mask_kitti_zhou'), 68 | tf.AddKeyValue('validation_clamp', 'validation_clamp_kitti'), 69 | tf.AddKeyValue('purposes', ('depth', )), 70 | ] 71 | 72 | dataset = StandardDataset( 73 | dataset='kitti', 74 | split='zhou_split', 75 | trainvaltest_split='test', 76 | video_mode='mono', 77 | stereo_mode='mono', 78 | keys_to_load=('color', 'depth'), 79 | data_transforms=transforms, 80 | video_frames=(0, ), 81 | disable_const_items=True 82 | ) 83 | 84 | loader = DataLoader( 85 | dataset, batch_size, False, 86 | num_workers=num_workers, pin_memory=True, drop_last=False 87 | ) 88 | 89 | print(f" - Can use {len(dataset)} images from the kitti (zhou_split) test set for depth evaluation", flush=True) 90 | 91 | return loader 92 | 93 | 94 | def kitti_kitti_validation(img_height, img_width, batch_size, num_workers): 95 | """A loader that loads images and depth ground truth for 96 | depth validation from the kitti validation set. 
97 | """ 98 | 99 | transforms = [ 100 | tf.CreateScaledImage(True), 101 | tf.Resize( 102 | (img_height, img_width), 103 | image_types=('color', ) 104 | ), 105 | tf.ConvertDepth(), 106 | tf.CreateColoraug(), 107 | tf.ToTensor(), 108 | tf.NormalizeZeroMean(), 109 | tf.AddKeyValue('domain', 'kitti_kitti_val_depth'), 110 | tf.AddKeyValue('validation_mask', 'validation_mask_kitti_kitti'), 111 | tf.AddKeyValue('validation_clamp', 'validation_clamp_kitti'), 112 | tf.AddKeyValue('purposes', ('depth', )), 113 | ] 114 | 115 | dataset = StandardDataset( 116 | dataset='kitti', 117 | split='kitti_split', 118 | trainvaltest_split='validation', 119 | video_mode='mono', 120 | stereo_mode='mono', 121 | keys_to_load=('color', 'depth'), 122 | data_transforms=transforms, 123 | video_frames=(0, ), 124 | disable_const_items=True 125 | ) 126 | 127 | loader = DataLoader( 128 | dataset, batch_size, False, 129 | num_workers=num_workers, pin_memory=True, drop_last=False 130 | ) 131 | 132 | print(f" - Can use {len(dataset)} images from the kitti (kitti_split) validation set for depth validation", 133 | flush=True) 134 | 135 | return loader 136 | 137 | 138 | def kitti_2015_train(img_height, img_width, batch_size, num_workers): 139 | """A loader that loads images and depth ground truth for 140 | depth evaluation from the kitti_2015 training set (but for evaluation). 141 | """ 142 | 143 | transforms = [ 144 | tf.CreateScaledImage(True), 145 | tf.Resize( 146 | (img_height, img_width), 147 | image_types=('color', ) 148 | ), 149 | tf.ConvertDepth(), 150 | tf.CreateColoraug(), 151 | tf.ToTensor(), 152 | tf.NormalizeZeroMean(), 153 | tf.AddKeyValue('domain', 'kitti_2015_train_depth'), 154 | tf.AddKeyValue('validation_mask', 'validation_mask_kitti_kitti'), 155 | tf.AddKeyValue('validation_clamp', 'validation_clamp_kitti'), 156 | tf.AddKeyValue('purposes', ('depth', )), 157 | ] 158 | 159 | dataset = StandardDataset( 160 | dataset='kitti_2015', 161 | trainvaltest_split='train', 162 | video_mode='mono', 163 | stereo_mode='mono', 164 | keys_to_load=('color', 'depth'), 165 | data_transforms=transforms, 166 | video_frames=(0, ), 167 | disable_const_items=True 168 | ) 169 | 170 | loader = DataLoader( 171 | dataset, batch_size, False, 172 | num_workers=num_workers, pin_memory=True, drop_last=False 173 | ) 174 | 175 | print(f" - Can use {len(dataset)} images from the kitti_2015 test set for depth evaluation", flush=True) 176 | 177 | return loader 178 | 179 | 180 | -------------------------------------------------------------------------------- /loaders/fns.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | VAL_MIN_DEPTH = 1e-3 4 | VAL_MAX_DEPTH = 80 5 | 6 | 7 | def _validation_mask_kitti_zhou(depth_gt): 8 | # Select only points that are not too far or too near to be useful 9 | dist_mask = (depth_gt > VAL_MIN_DEPTH) & (depth_gt < VAL_MAX_DEPTH) 10 | 11 | # Mask out points that lie outside the 12 | # area of the image that contains usually 13 | # useful pixels. 
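    # (The crop fractions below correspond to the Garg et al. crop that is
    # commonly used for KITTI depth evaluation.)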
14 | img_height = dist_mask.shape[1] 15 | img_width = dist_mask.shape[2] 16 | 17 | crop_top = int(0.40810811 * img_height) 18 | crop_bot = int(0.99189189 * img_height) 19 | crop_lft = int(0.03594771 * img_width) 20 | crop_rth = int(0.96405229 * img_width) 21 | 22 | crop_mask = torch.zeros_like(dist_mask) 23 | crop_mask[:, crop_top:crop_bot, crop_lft:crop_rth] = True 24 | 25 | # Combine the two masks from above 26 | # noinspection PyTypeChecker 27 | mask = dist_mask & crop_mask 28 | 29 | return mask 30 | 31 | 32 | def _validation_mask_kitti_kitti(depth_gt): 33 | mask = depth_gt > 0 34 | 35 | return mask 36 | 37 | 38 | def _validation_mask_cityscapes(depth_gt): 39 | 40 | dist_mask = (depth_gt > VAL_MIN_DEPTH) & (depth_gt < VAL_MAX_DEPTH) 41 | # Mask out points that lie outside the 42 | # area of the image that contains usually 43 | # useful pixels. 44 | img_height = dist_mask.shape[1] 45 | img_width = dist_mask.shape[2] 46 | 47 | crop_top = int(0.1 * img_height) 48 | crop_bot = int(0.7 * img_height) 49 | crop_lft = int(0.1 * img_width) 50 | crop_rth = int(0.9 * img_width) 51 | 52 | crop_mask = torch.zeros_like(dist_mask) 53 | crop_mask[:, crop_top:crop_bot, crop_lft:crop_rth] = True 54 | 55 | # Combine the two masks from above 56 | # noinspection PyTypeChecker 57 | mask = dist_mask & crop_mask 58 | 59 | return mask 60 | 61 | 62 | def _validation_clamp_kitti(depth_pred): 63 | depth_pred = depth_pred.clamp(VAL_MIN_DEPTH, VAL_MAX_DEPTH) 64 | 65 | return depth_pred 66 | 67 | 68 | def _validation_clamp_cityscapes(depth_pred): 69 | depth_pred = depth_pred.clamp(VAL_MIN_DEPTH, VAL_MAX_DEPTH) 70 | 71 | return depth_pred 72 | 73 | 74 | def get(name): 75 | return globals()[f'_{name}'] 76 | -------------------------------------------------------------------------------- /loaders/pose/__init__.py: -------------------------------------------------------------------------------- 1 | from .validation import * 2 | -------------------------------------------------------------------------------- /loaders/pose/validation.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | 3 | from dataloader.pt_data_loader.specialdatasets import StandardDataset 4 | import dataloader.pt_data_loader.mytransforms as tf 5 | 6 | 7 | def kitti_odom09_validation(img_height, img_width, batch_size, num_workers): 8 | """A loader that loads images and depth ground truth for 9 | depth validation from the kitti validation set. 
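    (Note: despite the wording above, this loader serves *pose* validation;
    it loads image pairs together with ground-truth poses.)

    Usage sketch (sizes are illustrative):

    >>> loader = kitti_odom09_validation(img_height=192, img_width=640,
    ...                                  batch_size=1, num_workers=4)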
10 | """ 11 | 12 | transforms = [ 13 | tf.CreateScaledImage(True), 14 | tf.Resize( 15 | (img_height, img_width), 16 | image_types=('color', ) 17 | ), 18 | tf.CreateColoraug(), 19 | tf.ToTensor(), 20 | tf.NormalizeZeroMean(), 21 | tf.AddKeyValue('domain', 'kitti_odom09_val_pose'), 22 | tf.AddKeyValue('purposes', ('depth', )), 23 | ] 24 | 25 | dataset = StandardDataset( 26 | dataset='kitti', 27 | split='odom09_split', 28 | trainvaltest_split='test', 29 | video_mode='video', 30 | stereo_mode='mono', 31 | keys_to_load=('color', 'poses'), 32 | keys_to_video=('color', ), 33 | data_transforms=transforms, 34 | video_frames=(0, -1, 1), 35 | disable_const_items=True 36 | ) 37 | 38 | loader = DataLoader( 39 | dataset, batch_size, False, 40 | num_workers=num_workers, pin_memory=True, drop_last=False 41 | ) 42 | 43 | print(f" - Can use {len(dataset)} images from the kitti (odom09 split) validation set for pose validation", 44 | flush=True) 45 | 46 | return loader 47 | 48 | 49 | def kitti_odom10_validation(img_height, img_width, batch_size, num_workers): 50 | """A loader that loads images and depth ground truth for 51 | depth validation from the kitti validation set. 52 | """ 53 | 54 | transforms = [ 55 | tf.CreateScaledImage(True), 56 | tf.Resize( 57 | (img_height, img_width), 58 | image_types=('color', ) 59 | ), 60 | tf.CreateColoraug(), 61 | tf.ToTensor(), 62 | tf.NormalizeZeroMean(), 63 | tf.AddKeyValue('domain', 'kitti_odom10_val_pose'), 64 | tf.AddKeyValue('purposes', ('depth', )), 65 | ] 66 | 67 | dataset = StandardDataset( 68 | dataset='kitti', 69 | split='odom10_split', 70 | trainvaltest_split='test', 71 | video_mode='video', 72 | stereo_mode='mono', 73 | keys_to_load=('color', 'poses'), 74 | keys_to_video=('color', ), 75 | data_transforms=transforms, 76 | video_frames=(0, -1, 1), 77 | disable_const_items=True 78 | ) 79 | 80 | loader = DataLoader( 81 | dataset, batch_size, False, 82 | num_workers=num_workers, pin_memory=True, drop_last=False 83 | ) 84 | 85 | print(f" - Can use {len(dataset)} images from the kitti (odom10 split) validation set for pose validation", 86 | flush=True) 87 | 88 | return loader 89 | -------------------------------------------------------------------------------- /loaders/segmentation/__init__.py: -------------------------------------------------------------------------------- 1 | from .train import * 2 | from .validation import * 3 | -------------------------------------------------------------------------------- /loaders/segmentation/train.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | 3 | from dataloader.pt_data_loader.specialdatasets import StandardDataset 4 | from dataloader.definitions.labels_file import labels_cityscape_seg 5 | import dataloader.pt_data_loader.mytransforms as tf 6 | 7 | 8 | def cityscapes_train(resize_height, resize_width, crop_height, crop_width, batch_size, num_workers): 9 | """A loader that loads images and ground truth for segmentation from the 10 | cityscapes training set. 
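    Each sample of a returned batch is a dict that contains, among others,
    the ('color_aug', 0, 0) input tensor, the ('segmentation', 0, 0) label
    map and the 'domain', 'purposes' and 'num_classes' entries added by the
    AddKeyValue transforms below.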
11 | """ 12 | 13 | labels = labels_cityscape_seg.getlabels() 14 | num_classes = len(labels_cityscape_seg.gettrainid2label()) 15 | 16 | transforms = [ 17 | tf.RandomHorizontalFlip(), 18 | tf.CreateScaledImage(), 19 | tf.Resize((resize_height, resize_width)), 20 | tf.RandomRescale(1.5), 21 | tf.RandomCrop((crop_height, crop_width)), 22 | tf.ConvertSegmentation(), 23 | tf.CreateColoraug(new_element=True), 24 | tf.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, gamma=0.0), 25 | tf.RemoveOriginals(), 26 | tf.ToTensor(), 27 | tf.NormalizeZeroMean(), 28 | tf.AddKeyValue('domain', 'cityscapes_train_seg'), 29 | tf.AddKeyValue('purposes', ('segmentation', 'domain')), 30 | tf.AddKeyValue('num_classes', num_classes) 31 | ] 32 | 33 | dataset_name = 'cityscapes' 34 | 35 | dataset = StandardDataset( 36 | dataset=dataset_name, 37 | trainvaltest_split='train', 38 | video_mode='mono', 39 | stereo_mode='mono', 40 | labels_mode='fromid', 41 | disable_const_items=True, 42 | labels=labels, 43 | keys_to_load=('color', 'segmentation'), 44 | data_transforms=transforms, 45 | video_frames=(0,) 46 | ) 47 | 48 | loader = DataLoader( 49 | dataset, batch_size, True, 50 | num_workers=num_workers, pin_memory=True, drop_last=True 51 | ) 52 | 53 | print(f" - Can use {len(dataset)} images from the cityscapes train set for segmentation training", flush=True) 54 | 55 | return loader 56 | -------------------------------------------------------------------------------- /loaders/segmentation/validation.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | 3 | from dataloader.pt_data_loader.specialdatasets import StandardDataset 4 | from dataloader.definitions.labels_file import labels_cityscape_seg 5 | import dataloader.pt_data_loader.mytransforms as tf 6 | 7 | 8 | def cityscapes_validation(resize_height, resize_width, batch_size, num_workers): 9 | """A loader that loads images and ground truth for segmentation from the 10 | cityscapes validation set 11 | """ 12 | 13 | labels = labels_cityscape_seg.getlabels() 14 | num_classes = len(labels_cityscape_seg.gettrainid2label()) 15 | 16 | transforms = [ 17 | tf.CreateScaledImage(True), 18 | tf.Resize((resize_height, resize_width), image_types=('color', )), 19 | tf.ConvertSegmentation(), 20 | tf.CreateColoraug(), 21 | tf.ToTensor(), 22 | tf.NormalizeZeroMean(), 23 | tf.AddKeyValue('domain', 'cityscapes_val_seg'), 24 | tf.AddKeyValue('purposes', ('segmentation', )), 25 | tf.AddKeyValue('num_classes', num_classes) 26 | ] 27 | 28 | dataset = StandardDataset( 29 | dataset='cityscapes', 30 | trainvaltest_split='validation', 31 | video_mode='mono', 32 | stereo_mode='mono', 33 | labels_mode='fromid', 34 | labels=labels, 35 | keys_to_load=['color', 'segmentation'], 36 | data_transforms=transforms, 37 | disable_const_items=True 38 | ) 39 | 40 | loader = DataLoader( 41 | dataset, batch_size, False, 42 | num_workers=num_workers, pin_memory=True, drop_last=False 43 | ) 44 | 45 | print(f" - Can use {len(dataset)} images from the cityscapes validation set for segmentation validation", 46 | flush=True) 47 | 48 | return loader 49 | 50 | 51 | def kitti_2015_train(resize_height, resize_width, batch_size, num_workers): 52 | """A loader that loads images and ground truth for segmentation from the 53 | kitti_2015 train set 54 | """ 55 | 56 | labels = labels_cityscape_seg.getlabels() 57 | num_classes = len(labels_cityscape_seg.gettrainid2label()) 58 | 59 | transforms = [ 60 | tf.CreateScaledImage(True), 61 | 
tf.Resize((resize_height, resize_width), image_types=('color', )), 62 | tf.ConvertSegmentation(), 63 | tf.CreateColoraug(), 64 | tf.ToTensor(), 65 | tf.NormalizeZeroMean(), 66 | tf.AddKeyValue('domain', 'kitti_2015_val_seg'), 67 | tf.AddKeyValue('purposes', ('segmentation', )), 68 | tf.AddKeyValue('num_classes', num_classes) 69 | ] 70 | 71 | dataset = StandardDataset( 72 | dataset='kitti_2015', 73 | trainvaltest_split='train', 74 | video_mode='mono', 75 | stereo_mode='mono', 76 | labels_mode='fromid', 77 | labels=labels, 78 | keys_to_load=['color', 'segmentation'], 79 | data_transforms=transforms, 80 | disable_const_items=True 81 | ) 82 | 83 | loader = DataLoader( 84 | dataset, batch_size, False, 85 | num_workers=num_workers, pin_memory=True, drop_last=False 86 | ) 87 | 88 | print(f" - Can use {len(dataset)} images from the kitti_2015 train set for segmentation validation", flush=True) 89 | 90 | return loader 91 | 92 | 93 | -------------------------------------------------------------------------------- /losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .segmentation import SegLosses 2 | from .depth import DepthLosses 3 | from .baselosses import * 4 | -------------------------------------------------------------------------------- /losses/depth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as functional 3 | 4 | import losses as trn_losses 5 | 6 | 7 | class DepthLosses(object): 8 | def __init__(self, device, disable_automasking=False, avg_reprojection=False, disparity_smoothness=0): 9 | self.automasking = not disable_automasking 10 | self.avg_reprojection = avg_reprojection 11 | self.disparity_smoothness = disparity_smoothness 12 | self.scaling_direction = "up" 13 | self.masked_supervision = True 14 | 15 | # noinspection PyUnresolvedReferences 16 | self.ssim = trn_losses.SSIM().to(device) 17 | self.smoothness = trn_losses.SmoothnessLoss() 18 | 19 | def _combined_reprojection_loss(self, pred, target): 20 | """Computes reprojection losses between a batch of predicted and target images 21 | """ 22 | 23 | # Calculate the per-color difference and the mean over all colors 24 | l1 = (pred - target).abs().mean(1, True) 25 | 26 | ssim = self.ssim(pred, target).mean(1, True) 27 | 28 | reprojection_loss = 0.85 * ssim + 0.15 * l1 29 | 30 | return reprojection_loss 31 | 32 | def _reprojection_losses(self, inputs, outputs, outputs_masked): 33 | """Compute the reprojection and smoothness losses for a minibatch 34 | """ 35 | 36 | frame_ids = tuple(frozenset(k[1] for k in outputs if k[0] == 'color')) 37 | resolutions = tuple(frozenset(k[2] for k in outputs if k[0] == 'color')) 38 | 39 | losses = dict() 40 | 41 | color = inputs["color", 0, 0] 42 | target = inputs["color", 0, 0] 43 | 44 | # Compute reprojection losses for the unwarped input images 45 | identity_reprojection_loss = tuple( 46 | self._combined_reprojection_loss(inputs["color", frame_id, 0], target) 47 | for frame_id in frame_ids 48 | ) 49 | identity_reprojection_loss = torch.cat(identity_reprojection_loss, 1) 50 | 51 | if self.avg_reprojection: 52 | identity_reprojection_loss = identity_reprojection_loss.mean(1, keepdim=True) 53 | 54 | for resolution in resolutions: 55 | # Compute reprojection losses (prev frame to cur and next frame to cur) 56 | reprojection_loss = tuple( 57 | self._combined_reprojection_loss(outputs["color", frame_id, resolution], target) 58 | for frame_id in frame_ids 59 | ) 60 | 
reprojection_loss = torch.cat(reprojection_loss, 1) 61 | 62 | # If avg_reprojection is disabled and automasking is enabled 63 | # there will be four "loss images" stacked in the end and 64 | # the per-pixel minimum will be selected for optimization. 65 | # Cases where this is relevant are, for example, image borders, 66 | # where information is missing, or areas occluded in one of the 67 | # input images but not all of them. 68 | # If avg_reprojection is enabled the number of images to select 69 | # the minimum loss from is reduced by average-combining them. 70 | if self.avg_reprojection: 71 | reprojection_loss = reprojection_loss.mean(1, keepdim=True) 72 | 73 | # Pixels that are equal in the (unwarped) source image 74 | # and target image (e.g. no motion) are not that helpful 75 | # and can be masked out. 76 | if self.automasking: 77 | reprojection_loss = torch.cat( 78 | (identity_reprojection_loss, reprojection_loss), 1 79 | ) 80 | # Select the per-pixel minimum loss from 81 | # (prev_unwarped, next_unwarped, prev_unwarped, prev_warped). 82 | # Pixels where the unwarped input images are selected 83 | # act as gradient black holes, as nothing is backpropagated 84 | # into the network. 85 | reprojection_loss, idxs = torch.min(reprojection_loss, dim=1) 86 | 87 | # Segmentation moving mask to mask DC objects 88 | if outputs_masked is not None: 89 | moving_mask = outputs_masked['moving_mask'] 90 | reprojection_loss = reprojection_loss * moving_mask 91 | 92 | loss = reprojection_loss.mean() 93 | 94 | if self.disparity_smoothness != 0: 95 | disp = outputs["disp", resolution] 96 | 97 | ref_color = functional.interpolate( 98 | color, disp.shape[2:], mode='bilinear', align_corners=False 99 | ) 100 | 101 | mean_disp = disp.mean((2, 3), True) 102 | norm_disp = disp / (mean_disp + 1e-7) 103 | 104 | disp_smth_loss = self.smoothness(norm_disp, ref_color) 105 | disp_smth_loss = self.disparity_smoothness * disp_smth_loss / (2 ** resolution) 106 | 107 | losses[f'disp_smth_loss/{resolution}'] = disp_smth_loss 108 | 109 | loss += disp_smth_loss 110 | 111 | losses[f'loss/{resolution}'] = loss 112 | 113 | losses['loss_depth_reprojection'] = sum( 114 | losses[f'loss/{resolution}'] 115 | for resolution in resolutions 116 | ) / len(resolutions) 117 | 118 | return losses 119 | 120 | def compute_losses(self, inputs, outputs, outputs_masked): 121 | losses = self._reprojection_losses(inputs, outputs, outputs_masked) 122 | losses['loss_depth'] = losses['loss_depth_reprojection'] 123 | 124 | return losses 125 | -------------------------------------------------------------------------------- /losses/segmentation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as functional 3 | 4 | from dataloader.eval.metrics import SegmentationRunningScore 5 | 6 | SEG_CLASS_WEIGHTS = ( 7 | 2.8149201869965, 6.9850029945374, 3.7890393733978, 9.9428062438965, 8 | 9.7702074050903, 9.5110931396484, 10.311357498169, 10.026463508606, 9 | 4.6323022842407, 9.5608062744141, 7.8698215484619, 9.5168733596802, 10 | 10.373730659485, 6.6616044044495, 10.260489463806, 10.287888526917, 11 | 10.289801597595, 10.405355453491, 10.138095855713, 0 12 | ) 13 | 14 | CLS_ROAD = 0 15 | CLS_SIDEWALK = 1 16 | CLS_BUILDING = 2 17 | CLS_WALL = 3 18 | CLS_FENCE = 4 19 | CLS_POLE = 5 20 | CLS_TRLIGHT = 6 21 | CLS_TRSIGN = 7 22 | CLS_VEGT = 8 23 | CLS_TERR = 9 24 | CLS_SKY = 10 25 | CLS_PERSON = 11 26 | CLS_RIDER = 12 27 | CLS_CAR = 13 28 | CLS_TRUCK = 14 29 | CLS_BUS = 15 30 | 
CLS_TRAIN = 16 31 | CLS_MCYCLE = 17 32 | CLS_BCYCLE = 18 33 | 34 | class SegLosses(object): 35 | def __init__(self, device): 36 | self.weights = torch.tensor(SEG_CLASS_WEIGHTS, device=device) 37 | 38 | def seg_losses(self, inputs, outputs): 39 | preds = outputs['segmentation_logits', 0] 40 | targets = inputs['segmentation', 0, 0][:, 0, :, :].long() 41 | 42 | losses = dict() 43 | losses["loss_seg"] = functional.cross_entropy( 44 | preds, targets, self.weights, ignore_index=255 45 | ) 46 | 47 | return losses 48 | 49 | 50 | class RemappingScore(object): # todo: delete this std->None. Soll direkt auf dem segmentationrunning score lauven 51 | REMAPS = { 52 | 'none': ( 53 | CLS_ROAD, CLS_SIDEWALK, CLS_BUILDING, CLS_WALL, 54 | CLS_FENCE, CLS_POLE, CLS_TRLIGHT, CLS_TRSIGN, 55 | CLS_VEGT, CLS_TERR, CLS_SKY, CLS_PERSON, 56 | CLS_RIDER, CLS_CAR, CLS_TRUCK, CLS_BUS, 57 | CLS_TRAIN, CLS_MCYCLE, CLS_BCYCLE 58 | ), 59 | 'dada_16': ( 60 | CLS_ROAD, CLS_SIDEWALK, CLS_BUILDING, CLS_WALL, 61 | CLS_FENCE, CLS_POLE, CLS_TRLIGHT, CLS_TRSIGN, 62 | CLS_VEGT, CLS_SKY, CLS_PERSON, 63 | CLS_RIDER, CLS_CAR, CLS_BUS, 64 | CLS_MCYCLE, CLS_BCYCLE 65 | ), 66 | 'dada_13': ( 67 | CLS_ROAD, CLS_SIDEWALK, CLS_BUILDING, 68 | CLS_TRLIGHT, CLS_TRSIGN, 69 | CLS_VEGT, CLS_SKY, CLS_PERSON, 70 | CLS_RIDER, CLS_CAR, CLS_BUS, 71 | CLS_MCYCLE, CLS_BCYCLE 72 | ), 73 | 'dada_7': ( 74 | (CLS_ROAD, CLS_SIDEWALK), 75 | (CLS_BUILDING, CLS_WALL, CLS_FENCE), 76 | (CLS_POLE, CLS_TRLIGHT, CLS_TRSIGN), 77 | (CLS_VEGT, CLS_TERR), 78 | CLS_SKY, 79 | (CLS_PERSON, CLS_RIDER), 80 | (CLS_CAR, CLS_TRUCK, CLS_BUS, CLS_TRAIN, CLS_MCYCLE, CLS_BCYCLE) 81 | ), 82 | 'gio_10': ( 83 | CLS_ROAD, CLS_BUILDING, 84 | CLS_POLE, CLS_TRLIGHT, CLS_TRSIGN, 85 | CLS_VEGT, CLS_TERR, CLS_SKY, 86 | CLS_CAR, CLS_TRUCK 87 | ) 88 | } 89 | 90 | def __init__(self, remaps=('none',)): 91 | self.scores = dict( 92 | (remap_name, SegmentationRunningScore(self._remap_len(remap_name))) # scores = dict mit remap name und leerer confusion matrix 93 | for remap_name in remaps 94 | ) 95 | 96 | def _remap_len(self, remap_name): 97 | return len(self.REMAPS[remap_name]) 98 | 99 | def _remap(self, remap_name, gt, pred): 100 | if remap_name == 'none': 101 | return (gt, pred) 102 | 103 | # TODO: cleanup and document 104 | gt_new = 255 + torch.zeros_like(gt) 105 | 106 | n, _, h, w = pred.shape 107 | c = self._remap_len(remap_name) 108 | device = pred.device 109 | dtype = pred.dtype 110 | 111 | pred_new = torch.zeros(n, c, h, w, device=device, dtype=dtype) 112 | 113 | for cls_to, clss_from in enumerate(self.REMAPS[remap_name]): 114 | clss_from = (clss_from, ) if isinstance(clss_from, int) else clss_from 115 | 116 | for cls_from in clss_from: 117 | gt_new[gt == cls_from] = cls_to 118 | pred_new[:,cls_to,:,:] += pred[:,cls_from,:,:] 119 | 120 | return (gt_new, pred_new) 121 | 122 | def update(self, gt, pred): 123 | pred = pred.exp() 124 | 125 | for remap_name, score in self.items(): 126 | gt_remap, pred_remap = self._remap(remap_name, gt, pred) # nichts passiert 127 | 128 | score.update( 129 | gt_remap.numpy(), 130 | pred_remap.argmax(1).cpu().numpy() 131 | ) 132 | 133 | def reset(self): 134 | for remap_name, score in self.items(): 135 | score.reset() 136 | 137 | def items(self): 138 | return iter((key, self[key]) for key in self) 139 | 140 | def __iter__(self): 141 | return iter(self.scores) 142 | 143 | def __getitem__(self, key): 144 | return self.scores[key] 145 | -------------------------------------------------------------------------------- /models/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ifnspaml/SGDepth/7d594b5153801d8240ec809ef6065eff1cd1f4fc/models/__init__.py -------------------------------------------------------------------------------- /models/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .grad_scaling_layers import ScaledSplit, GRL 2 | -------------------------------------------------------------------------------- /models/layers/grad_scaling_layers.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.autograd as autograd 6 | 7 | 8 | class ScaleGrad(autograd.Function): 9 | @staticmethod 10 | def forward(ctx, scale, *inputs): 11 | ctx.scale = scale 12 | 13 | outputs = inputs[0] if (len(inputs) == 1) else inputs 14 | 15 | return outputs 16 | 17 | @staticmethod 18 | def backward(ctx, *grad_outputs): 19 | grad_inputs = tuple( 20 | ctx.scale * grad 21 | for grad in grad_outputs 22 | ) 23 | 24 | return (None, *grad_inputs) 25 | 26 | 27 | class ScaledSplit(nn.Module): 28 | """Identity maps an input into outputs and scale gradients in the backward pass 29 | 30 | Args: 31 | *grad_weights: one or multiple weights to apply to the gradients in 32 | the backward pass 33 | 34 | Examples: 35 | 36 | >>> # Multiplex to two outputs, the gradients are scaled 37 | >>> # by 0.3 and 0.7 respectively 38 | >>> scp = ScaledSplit(0.3, 0.7) 39 | >>> # The input may consist of multiple tensors 40 | >>> inp 41 | (tensor(...), tensor(...)) 42 | >>> otp1, otp2 = scp(inp) 43 | >>> # Considering the forward pass both outputs behave just like inp. 44 | >>> # In the backward pass the gradients will be scaled by the respective 45 | >>> # weights 46 | >>> otp1 47 | (tensor(...), tensor(...)) 48 | >>> otp2 49 | (tensor(...), tensor(...)) 50 | """ 51 | 52 | def __init__(self, *grad_weights): 53 | super().__init__() 54 | self.set_scales(*grad_weights) 55 | 56 | def set_scales(self, *grad_weights): 57 | self.grad_weights = grad_weights 58 | 59 | def get_scales(self, *grad_weights): 60 | return self.grad_weights 61 | 62 | def forward(self, *inputs): 63 | # Generate nested tuples, where the outer layer 64 | # corresponds to the output & grad_weight pairs 65 | # and the inner layer corresponds to the list of inputs 66 | split = tuple( 67 | tuple(ScaleGrad.apply(gw, inp) for inp in inputs) 68 | for gw in self.grad_weights 69 | ) 70 | 71 | # Users that passed only one input don't expect 72 | # a nested tuple as output but rather a tuple of tensors, 73 | # so unpack if there was only one input 74 | unnest_inputs = tuple( 75 | s[0] if (len(s) == 1) else s 76 | for s in split 77 | ) 78 | 79 | # Users that specified only one output weight 80 | # do not expect a tuple of tensors put rather 81 | # a single tensor, so unpack if there was only one weight 82 | unnest_outputs = unnest_inputs[0] if (len(unnest_inputs) == 1) else unnest_inputs 83 | 84 | return unnest_outputs 85 | 86 | 87 | class GRL(ScaledSplit): 88 | """Identity maps an input and invert the gradient in the backward pass 89 | 90 | This layer can be used in adversarial training to train an encoder 91 | encoder network to become _worse_ a a specific task. 
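    Example (a sketch: `encoder`, `domain_classifier` and `images` are
    placeholders, not modules from this repository):

    >>> grl = GRL()
    >>> features = encoder(images)
    >>> logits = domain_classifier(grl(features))
    >>> # The classifier is trained as usual, while the encoder receives the
    >>> # negated gradient and is pushed to make the classifier's task harder.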
92 | """ 93 | 94 | def __init__(self): 95 | super().__init__(-1) 96 | 97 | 98 | class TestScaledSplit(unittest.TestCase): 99 | def test_siso(self): 100 | factor = 0.5 101 | 102 | scp = ScaledSplit(factor) 103 | 104 | # Construct a toy network with inputs and weights 105 | inp = torch.tensor([ 1, 1, 1], dtype=torch.float32, requires_grad=True) 106 | wgt = torch.tensor([-1, 0, 1], dtype=torch.float32, requires_grad=False) 107 | 108 | pre_split = inp * wgt 109 | post_split = scp.forward(pre_split) 110 | 111 | self.assertTrue(torch.equal(pre_split, post_split), 'ScaledSplit produced non-identity in forward pass') 112 | 113 | # The network's output is a single number 114 | sum_pre = pre_split.sum() 115 | sum_post = post_split.sum() 116 | 117 | # Compute the gradients with and without scaling 118 | grad_pre, = autograd.grad(sum_pre, inp, retain_graph=True) 119 | grad_post, = autograd.grad(sum_post, inp) 120 | 121 | # Check if the scaling matches expectations 122 | self.assertTrue(torch.equal(grad_pre * factor, grad_post), 'ScaledSplit produced inconsistent gradient') 123 | 124 | def test_simo(self): 125 | # TODO 126 | 127 | pass 128 | 129 | def test_miso(self): 130 | # TODO 131 | 132 | pass 133 | 134 | 135 | def test_mimo(self): 136 | # TODO 137 | 138 | pass 139 | 140 | 141 | if __name__ == '__main__': 142 | # Use python3 -m grad_scaling_layers 143 | # to run the unit tests and check if you 144 | # broke something 145 | 146 | unittest.main() 147 | -------------------------------------------------------------------------------- /models/networks/__init__.py: -------------------------------------------------------------------------------- 1 | from .partial_decoder import PartialDecoder 2 | from .multi_res_output import MultiResDepth, MultiResSegmentation 3 | from .resnet_encoder import ResnetEncoder 4 | from .pose_decoder import PoseDecoder 5 | -------------------------------------------------------------------------------- /models/networks/multi_res_output.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class MultiRes(nn.Module): 6 | """ Directly generate target-space outputs from (intermediate) decoder layer outputs 7 | Args: 8 | dec_chs: A list of decoder output channel counts 9 | out_chs: output channels to generate 10 | pp: A function to call on any output tensor 11 | for post-processing (e.g. a non-linear activation) 12 | """ 13 | 14 | def __init__(self, dec_chs, out_chs, pp=None): 15 | super().__init__() 16 | 17 | self.pad = nn.ReflectionPad2d(1) 18 | 19 | self.convs = nn.ModuleList( 20 | nn.Conv2d(in_chs, out_chs, 3) 21 | for in_chs in dec_chs[::-1] 22 | ) 23 | 24 | self.pp = pp if (pp is not None) else self._identity_pp 25 | 26 | def _identity_pp(self, x): 27 | return x 28 | 29 | def forward(self, *x): 30 | out = tuple( 31 | self.pp(conv(self.pad(inp))) 32 | for conv, inp in zip(self.convs[::-1], x) 33 | ) 34 | 35 | return out 36 | 37 | 38 | class MultiResDepth(MultiRes): 39 | def __init__(self, dec_chs, out_chs=1): 40 | super().__init__(dec_chs, out_chs, nn.Sigmoid()) 41 | 42 | # Just like in the PoseDecoder, where outputting 43 | # large translations at the beginning of training 44 | # is harmful for stability, outputting large depths 45 | # at the beginning of training is a source 46 | # of instability as well. Increasing the bias on 47 | # the disparity output decreases the depth output.
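        # (Rough illustration, added for clarity and not part of the original
        # comment: with the Sigmoid post-processing above, a bias of +5 pushes
        # the raw disparity output towards sigmoid(5) ≈ 0.993, i.e. close to
        # the maximum disparity and therefore close to the minimum depth at
        # the start of training.)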
48 | for conv in self.convs: 49 | conv.bias.data += 5 50 | 51 | 52 | class MultiResSegmentation(MultiRes): 53 | def __init__(self, dec_chs, out_chs=20): 54 | super().__init__(dec_chs, out_chs) 55 | 56 | -------------------------------------------------------------------------------- /models/networks/partial_decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class PreConvBlock(nn.Module): 5 | """Decoder basic block 6 | """ 7 | 8 | def __init__(self, pos, n_in, n_out): 9 | super().__init__() 10 | self.pos = pos 11 | 12 | self.pad = nn.ReflectionPad2d(1) 13 | self.conv = nn.Conv2d(n_in, n_out, 3) 14 | self.nl = nn.ELU() 15 | 16 | def forward(self, *x): 17 | if self.pos == 0: 18 | x_pre = x[:self.pos] 19 | x_cur = x[self.pos] 20 | x_pst = x[self.pos + 1:] 21 | else: 22 | x_pre = x[:self.pos] 23 | x_cur = x[self.pos - 1] 24 | x_pst = x[self.pos + 1:] 25 | 26 | x_cur = self.pad(x_cur) 27 | x_cur = self.conv(x_cur) 28 | x_cur = self.nl(x_cur) 29 | 30 | return x_pre + (x_cur, ) + x_pst 31 | 32 | class UpSkipBlock(nn.Module): 33 | """Decoder basic block 34 | 35 | Perform the following actions: 36 | - Upsample by factor 2 37 | - Concatenate skip connections (if any) 38 | - Convolve 39 | """ 40 | 41 | def __init__(self, pos, ch_in, ch_skip, ch_out): 42 | super().__init__() 43 | self.pos = pos 44 | 45 | self.up = nn.Upsample(scale_factor=2) 46 | 47 | self.pad = nn.ReflectionPad2d(1) 48 | self.conv = nn.Conv2d(ch_in + ch_skip, ch_out, 3) 49 | self.nl = nn.ELU() 50 | 51 | def forward(self, *x): 52 | if self.pos == 5: 53 | x_pre = x[:self.pos - 1 ] 54 | x_new = x[self.pos - 1] 55 | x_skp = tuple() 56 | x_pst = x[self.pos:] 57 | else: 58 | x_pre = x[:self.pos - 1] 59 | x_new = x[self.pos - 1] 60 | x_skp = x[self.pos] 61 | x_pst = x[self.pos:] 62 | 63 | # upscale the input: 64 | x_new = self.up(x_new) 65 | 66 | # Mix in skip connections from the encoder 67 | # (if there are any) 68 | if len(x_skp) > 0: 69 | x_new = torch.cat((x_new, x_skp), 1) 70 | 71 | # Combine up-scaled input and skip connections 72 | x_new = self.pad(x_new) 73 | x_new = self.conv(x_new) 74 | x_new = self.nl(x_new) 75 | 76 | return x_pre + (x_new, ) + x_pst 77 | 78 | class PartialDecoder(nn.Module): 79 | """Decode some features encoded by a feature extractor 80 | 81 | Args: 82 | chs_dec: A list of decoder-internal channel counts 83 | chs_enc: A list of channel counts that we get from the encoder 84 | start: The first step of the decoding process this decoder should perform 85 | end: The last step of the decoding process this decoder should perform 86 | """ 87 | 88 | def __init__(self, chs_dec, chs_enc, start=0, end=None): 89 | super().__init__() 90 | 91 | self.start = start 92 | self.end = (2 * len(chs_dec)) if (end is None) else end 93 | 94 | self.chs_dec = tuple(chs_dec) 95 | self.chs_enc = tuple(chs_enc) 96 | 97 | self.blocks = nn.ModuleDict() 98 | 99 | for step in range(self.start, self.end): 100 | macro_step = step // 2 101 | mini_step = step % 2 102 | pos_x = (step + 1) // 2 103 | 104 | # The decoder steps are interleaved ... 105 | if (mini_step == 0): 106 | n_in = self.chs_dec[macro_step - 1] if (macro_step > 0) else self.chs_enc[0] 107 | n_out = self.chs_dec[macro_step] 108 | 109 | # ... first there is a pre-convolution ... 110 | self.blocks[f'step_{step}'] = PreConvBlock(pos_x, n_in, n_out) 111 | 112 | else: 113 | # ... and then an upsampling and convolution with 114 | # the skip connections input. 
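                # (Sketch for orientation, assuming the 5-scale setup used in
                # SGDepthCommon with chs_dec = (256, 128, 64, 32, 16): the full
                # decoder then consists of 2 * len(chs_dec) = 10 steps, where
                # even steps 0, 2, ... are PreConvBlocks and odd steps
                # 1, 3, ... are UpSkipBlocks like the one built below.)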
115 | n_in = self.chs_dec[macro_step] 116 | n_skips = self.chs_enc[macro_step + 1] if ((macro_step + 1) < len(chs_enc)) else 0 117 | n_out = self.chs_dec[macro_step] 118 | 119 | self.blocks[f'step_{step}'] = UpSkipBlock(pos_x, n_in, n_skips, n_out) 120 | 121 | def chs_x(self): 122 | return self.chs_dec 123 | 124 | @classmethod 125 | def gen_head(cls, chs_dec, chs_enc, end=None): 126 | return cls(chs_dec, chs_enc, 0, end) 127 | 128 | @classmethod 129 | def gen_tail(cls, head, end=None): 130 | return cls(head.chs_dec, head.chs_enc, head.end, end) 131 | 132 | def forward(self, *x): 133 | for step in range(self.start, self.end): 134 | x = self.blocks[f'step_{step}'](*x) 135 | return x 136 | -------------------------------------------------------------------------------- /models/networks/pose_decoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | class PoseDecoder(nn.Module): 4 | def __init__(self, input_channels): 5 | super().__init__() 6 | 7 | self.nl = nn.ReLU() 8 | self.squeeze = nn.Conv2d(input_channels, 256, 1) 9 | self.conv_1 = nn.Conv2d(256, 256, 3, 1, 1) 10 | self.conv_2 = nn.Conv2d(256, 256, 3, 1, 1) 11 | self.conv_3 = nn.Conv2d(256, 6, 1) 12 | 13 | # The original monodepth2 PoseDecoder 14 | # included a constant multiplication by 15 | # 0.01 in the forward pass, possibly to 16 | # make x_angle and x_translation tiny at 17 | # the beginning of training for stability. 18 | # In my opinion this hurts performance 19 | # with weight_decay enabled. 20 | # Scaling the initial weights should have 21 | # a similar effect. 22 | self.conv_3.weight.data *= 0.01 23 | self.conv_3.bias.data *= 0.01 24 | 25 | def forward(self, x): 26 | x = self.squeeze(x) 27 | x = self.nl(x) 28 | 29 | x = self.conv_1(x) 30 | x = self.nl(x) 31 | 32 | x = self.conv_2(x) 33 | x = self.nl(x) 34 | 35 | x = self.conv_3(x) 36 | x = x.mean((3, 2)).view(-1, 1, 1, 6) 37 | 38 | x_angle = x[..., :3] 39 | x_translation = x[..., 3:] 40 | 41 | return x_angle, x_translation 42 | -------------------------------------------------------------------------------- /models/networks/resnet_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models as models 4 | 5 | RESNETS = { 6 | 18: models.resnet18, 7 | 34: models.resnet34, 8 | 50: models.resnet50, 9 | 101: models.resnet101, 10 | 152: models.resnet152 11 | } 12 | 13 | 14 | class ResnetEncoder(nn.Module): 15 | """A ResNet that handles multiple input images and outputs skip connections""" 16 | 17 | def __init__(self, num_layers, pretrained, num_input_images=1): 18 | super().__init__() 19 | 20 | if num_layers not in RESNETS: 21 | raise ValueError(f"{num_layers} is not a valid number of resnet layers") 22 | 23 | self.encoder = RESNETS[num_layers](pretrained) 24 | 25 | # Up until this point self.encoder handles 3 input channels. 26 | # For pose estimation we want to use two input images, 27 | # which means 6 input channels. 28 | # Extend the encoder in a way that makes it equivalent 29 | # to the single-image version when fed with an input image 30 | # repeated num_input_images times. 
31 | # Further information can be found in Appendix Section B of: 32 | # https://arxiv.org/pdf/1806.01260.pdf 33 | # Note that in this step only the weights are changed 34 | # to handle 6 (or even more) input channels. 35 | # For clarity the attribute "in_channels" should be changed too, 36 | # although it seems to me that it has no influence on the functionality. 37 | self.encoder.conv1.weight.data = self.encoder.conv1.weight.data.repeat( 38 | (1, num_input_images, 1, 1) 39 | ) / num_input_images 40 | 41 | # Change attribute "in_channels" for clarity 42 | self.encoder.conv1.in_channels = num_input_images * 3 # Number of channels per image = 3 43 | 44 | # Remove fully connected layer 45 | self.encoder.fc = None 46 | 47 | if num_layers > 34: 48 | self.num_ch_enc = (64, 256, 512, 1024, 2048) 49 | else: 50 | self.num_ch_enc = (64, 64, 128, 256, 512) 51 | 52 | def forward(self, l_0): 53 | l_0 = self.encoder.conv1(l_0) 54 | l_0 = self.encoder.bn1(l_0) 55 | l_0 = self.encoder.relu(l_0) 56 | 57 | l_1 = self.encoder.maxpool(l_0) 58 | l_1 = self.encoder.layer1(l_1) 59 | 60 | l_2 = self.encoder.layer2(l_1) 61 | l_3 = self.encoder.layer3(l_2) 62 | l_4 = self.encoder.layer4(l_3) 63 | 64 | return (l_0, l_1, l_2, l_3, l_4) 65 | -------------------------------------------------------------------------------- /models/sgdepth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from . import networks 5 | from . import layers 6 | 7 | 8 | class SGDepthCommon(nn.Module): 9 | def __init__(self, num_layers, split_pos, grad_scales=(0.9, 0.1), pretrained=False): 10 | super().__init__() 11 | 12 | self.encoder = networks.ResnetEncoder(num_layers, pretrained) 13 | self.num_layers = num_layers # This information is needed in the train loop for the sequential training 14 | 15 | # Number of channels for the skip connections and internal connections 16 | # of the decoder network, ordered from input to output 17 | self.shape_enc = tuple(reversed(self.encoder.num_ch_enc)) 18 | self.shape_dec = (256, 128, 64, 32, 16) 19 | 20 | self.decoder = networks.PartialDecoder.gen_head(self.shape_dec, self.shape_enc, split_pos) 21 | self.split = layers.ScaledSplit(*grad_scales) 22 | 23 | def set_gradient_scales(self, depth, segmentation): 24 | self.split.set_scales(depth, segmentation) 25 | 26 | def get_gradient_scales(self): 27 | return self.split.get_scales() 28 | 29 | def forward(self, x): 30 | # The encoder produces outputs in the order 31 | # (highest res, second highest res, …, lowest res) 32 | x = self.encoder(x) 33 | 34 | # The decoder expects its inputs in the order they are 35 | # used. E.g.
(lowest res, second lowest res, …, highest res) 36 | x = tuple(reversed(x)) 37 | 38 | # Replace some elements in the x tuple by decoded 39 | # tensors and leave others as-is 40 | x = self.decoder(*x) # CHANGE ME BACK TO THIS 41 | 42 | # Setup gradient scaling in the backward pass 43 | x = self.split(*x) 44 | 45 | # Experimental Idea: We want the decoder layer to be trained, so pass x to the decoder AFTER x was passed 46 | # to self.split which scales all gradients to 0 (if grad_scales are 0) 47 | # x = (self.decoder(*x[0]), ) + (self.decoder(*x[1]), ) + (self.decoder(*x[2]), ) 48 | 49 | return x 50 | 51 | #def get_last_shared_layer(self): 52 | # return self.encoder.encoder.layer4 53 | 54 | 55 | class SGDepthDepth(nn.Module): 56 | def __init__(self, common, resolutions=1): 57 | super().__init__() 58 | 59 | self.resolutions = resolutions 60 | 61 | self.decoder = networks.PartialDecoder.gen_tail(common.decoder) 62 | self.multires = networks.MultiResDepth(self.decoder.chs_x()[-resolutions:]) 63 | 64 | def forward(self, *x): 65 | x = self.decoder(*x) 66 | x = self.multires(*x[-self.resolutions:]) 67 | return x 68 | 69 | 70 | class SGDepthSeg(nn.Module): 71 | def __init__(self, common): 72 | super().__init__() 73 | 74 | self.decoder = networks.PartialDecoder.gen_tail(common.decoder) 75 | self.multires = networks.MultiResSegmentation(self.decoder.chs_x()[-1:]) 76 | self.nl = nn.Softmax2d() 77 | 78 | def forward(self, *x): 79 | x = self.decoder(*x) 80 | x = self.multires(*x[-1:]) 81 | x_lin = x[-1] 82 | 83 | return x_lin 84 | 85 | 86 | class SGDepthPose(nn.Module): 87 | def __init__(self, num_layers, pretrained=False): 88 | super().__init__() 89 | 90 | self.encoder = networks.ResnetEncoder( 91 | num_layers, pretrained, num_input_images=2 92 | ) 93 | 94 | self.decoder = networks.PoseDecoder(self.encoder.num_ch_enc[-1]) 95 | 96 | def _transformation_from_axisangle(self, axisangle): 97 | n, h, w = axisangle.shape[:3] 98 | 99 | angles = axisangle.norm(dim=3) 100 | axes = axisangle / (angles.unsqueeze(-1) + 1e-7) 101 | 102 | # Implement the matrix from [1] with an additional identity fourth dimension 103 | # [1]: https://en.wikipedia.org/wiki/Transformation_matrix#Rotation_2 104 | 105 | angles_cos = angles.cos() 106 | angles_sin = angles.sin() 107 | 108 | res = torch.zeros(n, h, w, 4, 4, device=axisangle.device) 109 | res[:,:,:,:3,:3] = (1 - angles_cos.view(n,h,w,1,1)) * (axes.unsqueeze(-2) * axes.unsqueeze(-1)) 110 | 111 | res[:,:,:,0,0] += angles_cos 112 | res[:,:,:,1,1] += angles_cos 113 | res[:,:,:,2,2] += angles_cos 114 | 115 | sl = axes[:,:,:,0] * angles_sin 116 | sm = axes[:,:,:,1] * angles_sin 117 | sn = axes[:,:,:,2] * angles_sin 118 | 119 | res[:,:,:,0,1] -= sn 120 | res[:,:,:,1,0] += sn 121 | 122 | res[:,:,:,1,2] -= sl 123 | res[:,:,:,2,1] += sl 124 | 125 | res[:,:,:,2,0] -= sm 126 | res[:,:,:,0,2] += sm 127 | 128 | res[:,:,:,3,3] = 1.0 129 | 130 | return res 131 | 132 | def _transformation_from_translation(self, translation): 133 | n, h, w = translation.shape[:3] 134 | 135 | # Implement the matrix from [1] with an additional dimension 136 | # [1]: https://en.wikipedia.org/wiki/Transformation_matrix#Affine_transformations 137 | 138 | res = torch.zeros(n, h, w, 4, 4, device=translation.device) 139 | res[:,:,:,:3,3] = translation 140 | res[:,:,:,0,0] = 1.0 141 | res[:,:,:,1,1] = 1.0 142 | res[:,:,:,2,2] = 1.0 143 | res[:,:,:,3,3] = 1.0 144 | 145 | return res 146 | 147 | def forward(self, x, invert): 148 | x = self.encoder(x) 149 | x = x[-1] # take only the feature map of the last layer ... 
150 | 151 | x_axisangle, x_translation = self.decoder(x) # ... and pass it through the decoder 152 | 153 | x_rotation = self._transformation_from_axisangle(x_axisangle) 154 | 155 | if not invert: 156 | x_translation = self._transformation_from_translation(x_translation) 157 | 158 | return x_translation @ x_rotation 159 | 160 | else: 161 | x_rotation = x_rotation.transpose(3, 4) 162 | x_translation = -x_translation 163 | 164 | x_translation = self._transformation_from_translation(x_translation) 165 | 166 | return x_rotation @ x_translation 167 | 168 | 169 | class SGDepth(nn.Module): 170 | KEY_FRAME_CUR = ('color_aug', 0, 0) 171 | KEY_FRAME_PREV = ('color_aug', -1, 0) 172 | KEY_FRAME_NEXT = ('color_aug', 1, 0) 173 | 174 | def __init__(self, split_pos=1, num_layers=18, grad_scale_depth=0.95, grad_scale_seg=0.05, 175 | weights_init='pretrained', resolutions_depth=1, num_layers_pose=18): 176 | 177 | super().__init__() 178 | 179 | # sgdepth allowed for five possible split positions. 180 | # The PartialDecoder developed as part of sgdepth 181 | # is a bit more flexible and allows splits to be 182 | # placed in between sgdepth's splits. 183 | # As this class is meant to maximize compatibility 184 | # with sgdepth, the line below translates between 185 | # the split position definitions. 186 | split_pos = max((2 * split_pos) - 1, 0) 187 | 188 | # The Depth and the Segmentation Network have a common (=shared) 189 | # Encoder ("Feature Extractor") 190 | self.common = SGDepthCommon( 191 | num_layers, split_pos, (grad_scale_depth, grad_scale_seg), 192 | weights_init == 'pretrained' 193 | ) 194 | 195 | # While Depth and Seg Network have a shared Encoder, 196 | # each one has its own Decoder 197 | self.depth = SGDepthDepth(self.common, resolutions_depth) 198 | self.seg = SGDepthSeg(self.common) 199 | 200 | # The Pose network has its own Encoder ("Feature Extractor") and Decoder 201 | self.pose = SGDepthPose( 202 | num_layers_pose, 203 | weights_init == 'pretrained' 204 | ) 205 | 206 | def _batch_pack(self, group): 207 | # Concatenate a list of tensors and remember how 208 | # to tear them apart again 209 | 210 | group = tuple(group) 211 | 212 | dims = tuple(b.shape[0] for b in group) # dims = (DEFAULT_DEPTH_BATCH_SIZE, DEFAULT_SEG_BATCH_SIZE) 213 | group = torch.cat(group, dim=0) # concatenate along the first axis, so along the batch axis 214 | 215 | return dims, group 216 | 217 | def _multi_batch_unpack(self, dims, *xs): 218 | xs = tuple( 219 | tuple(x.split(dims)) 220 | for x in xs 221 | ) 222 | 223 | # xs, as of now, is indexed like this: 224 | # xs[ENCODER_LAYER][DATASET_IDX], the lines below swap 225 | # this around to xs[DATASET_IDX][ENCODER_LAYER], so that 226 | # xs[DATASET_IDX] can be fed into the decoders. 227 | xs = tuple(zip(*xs)) 228 | 229 | return xs 230 | 231 | def _check_purposes(self, dataset, purpose): 232 | # mytransforms.AddKeyValue is used in the loaders 233 | # to give each image a tuple of 'purposes'. 234 | # As of now these purposes can be 'depth' and 'segmentation'. 235 | # The torch DataLoader collates these per-image purposes 236 | # into a list of them for each batch.
237 | # Check all purposes in this collated list for the requested 238 | # purpose (if you did not do anything wonky, all purposes in a 239 | # batch should be equal). 240 | 241 | for purpose_field in dataset['purposes']: 242 | if purpose_field[0] == purpose: 243 | return True 244 | 245 | def set_gradient_scales(self, depth, segmentation): 246 | self.common.set_gradient_scales(depth, segmentation) 247 | 248 | def get_gradient_scales(self): 249 | return self.common.get_gradient_scales() 250 | 251 | def forward(self, batch): 252 | # Stitch together all current input frames 253 | # in the input group, so that batch normalization 254 | # in the encoder is done over all datasets/domains. 255 | dims, x = self._batch_pack( 256 | dataset[self.KEY_FRAME_CUR] 257 | for dataset in batch 258 | ) 259 | 260 | # Feed the stitched-together input tensor through 261 | # the common network part and generate two output 262 | # tuples that look exactly the same in the forward 263 | # pass, but scale gradients differently in the backward pass. 264 | x_depth, x_seg = self.common(x) 265 | 266 | # Cut the stitched-together tensors along the 267 | # dataset boundaries so further processing can 268 | # be performed on a per-dataset basis. 269 | # x[DATASET_IDX][ENCODER_LAYER] 270 | x_depth = self._multi_batch_unpack(dims, *x_depth) 271 | x_seg = self._multi_batch_unpack(dims, *x_seg) 272 | 273 | outputs = list(dict() for _ in batch) 274 | 275 | # All the way back in the loaders each dataset is assigned one or more 'purposes'. 276 | # For datasets with the 'depth' purpose set, the outputs[DATASET_IDX] dict will be 277 | # populated with depth outputs. 278 | # Datasets with the 'segmentation' purpose are handled accordingly. 279 | # Whether the pose outputs are populated depends on the presence of 280 | # ('color_aug', -1, 0)/('color_aug', 1, 0) keys in the Dataset.
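        # (Summary of the keys written below, for orientation: depth datasets
        # receive ('disp', res) entries, segmentation datasets receive
        # ('segmentation_logits', 0), and datasets that provide previous/next
        # frames additionally receive ('cam_T_cam', 0, -1) and
        # ('cam_T_cam', 0, 1) pose outputs.)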
281 | for idx, dataset in enumerate(batch): 282 | if self._check_purposes(dataset, 'depth'): 283 | x = x_depth[idx] 284 | x = self.depth(*x) 285 | x = reversed(x) 286 | 287 | for res, disp in enumerate(x): 288 | outputs[idx]['disp', res] = disp 289 | 290 | if self._check_purposes(dataset, 'segmentation'): 291 | x = x_seg[idx] 292 | x = self.seg(*x) 293 | 294 | outputs[idx]['segmentation_logits', 0] = x 295 | 296 | if self.KEY_FRAME_PREV in dataset: 297 | frame_prev = dataset[self.KEY_FRAME_PREV] 298 | frame_cur = dataset[self.KEY_FRAME_CUR] 299 | 300 | # Concatenating joins the previous and the current frame 301 | # tensors along the first axis/dimension, 302 | # which is the axis for the color channel 303 | frame_prev_cur = torch.cat((frame_prev, frame_cur), dim=1) 304 | 305 | outputs[idx]['cam_T_cam', 0, -1] = self.pose(frame_prev_cur, invert=True) 306 | 307 | if self.KEY_FRAME_NEXT in dataset: 308 | frame_cur = dataset[self.KEY_FRAME_CUR] 309 | frame_next = dataset[self.KEY_FRAME_NEXT] 310 | 311 | frame_cur_next = torch.cat((frame_cur, frame_next), 1) 312 | outputs[idx]['cam_T_cam', 0, 1] = self.pose(frame_cur_next, invert=False) 313 | 314 | return tuple(outputs) 315 | -------------------------------------------------------------------------------- /perspective_resample.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as functional 3 | 4 | 5 | class PerspectiveResampler(object): 6 | def __init__(self, max_depth=100, min_depth=0.01, min_sampling_res=3): 7 | self.min_disp = 1 / max_depth 8 | self.max_disp = 1 / min_depth 9 | 10 | self.min_sampling_res = min_sampling_res 11 | 12 | if self.min_sampling_res < 3: 13 | raise ValueError( 14 | 'Bilinear sampling needs at least a 2x2 image to sample from. ' 15 | 'Increase --min_sampling_res to at least 3.' 16 | ) 17 | 18 | def _homogeneous_grid_like(self, ref): 19 | n, c, h, w = ref.shape 20 | 21 | grid_x = torch.linspace(0, w - 1, w, device=ref.device) 22 | grid_y = torch.linspace(0, h - 1, h, device=ref.device) 23 | 24 | grid_x = grid_x.view(1, 1, -1, 1).expand(n, h, w, 1) 25 | grid_y = grid_y.view(1, -1, 1, 1).expand(n, h, w, 1) 26 | 27 | grid_w = torch.ones(n, h, w, 1, device=ref.device) 28 | 29 | grid = torch.cat((grid_x, grid_y, grid_w), 3) 30 | 31 | return grid 32 | 33 | def _to_pointcloud(self, depth, cam_inv): 34 | # Generate a grid that resembles the coordinates 35 | # of the projection plane 36 | grid = self._homogeneous_grid_like(depth) 37 | 38 | # Use the inverse camera matrix to generate 39 | # a vector that points from the camera focal 40 | # point to the pixel position on the projection plane. 41 | # Use unsqueeze and squeeze to coax pytorch into 42 | # performing a matrix - vector product. 43 | pointcloud = (cam_inv @ grid.unsqueeze(-1)).squeeze(-1) 44 | 45 | # Transpose the depth from shape (n, 1, h, w) to (n, h, w, 1) 46 | # to match the pointcloud shape (n, h, w, 3) and 47 | # multiply the projection plane vectors by the depth. 48 | pointcloud = pointcloud * depth.permute(0, 2, 3, 1) 49 | 50 | # Make the pointcloud coordinates homogeneous 51 | # by extending each vector with a one. 52 | pointcloud = torch.cat( 53 | (pointcloud, torch.ones_like(pointcloud[:,:,:,:1])), 54 | 3 55 | ) 56 | 57 | return pointcloud 58 | 59 | def _surface_normal(self, pointcloud): 60 | # Pointcloud has shape (n, h, w, 4). 61 | # Calculate the vectors pointing from a point associated with 62 | # a pixel to the points associated with the pixels down and to the right. 
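        # (In other words: dx and dy below are finite differences of the 3D
        # points along the image axes; their cross product is orthogonal to
        # the local surface patch and, after normalization, is used as the
        # per-pixel surface normal.)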
63 | dy = pointcloud[:,1:,:-1,:3] - pointcloud[:,:-1,:-1,:3] 64 | dx = pointcloud[:,:-1,1:,:3] - pointcloud[:,:-1,:-1,:3] 65 | 66 | # Calculate the normal vector for the plane spanned 67 | # by the vectors above. 68 | n = torch.cross(dx, dy, 3) 69 | n = n / (n.norm(2, 3, True) + 1e-8) 70 | 71 | return n 72 | 73 | def _change_perspective(self, pointcloud, cam_to_cam): 74 | """ Use the translation/rotation matrix to move 75 | a pointcloud between reference frames 76 | """ 77 | 78 | # Use unsqueeze and squeeze to coax pytorch into 79 | # performing a matrix - vector product. 80 | pointcloud = (cam_to_cam @ pointcloud.unsqueeze(-1)).squeeze(-1) 81 | 82 | return pointcloud 83 | 84 | def _to_sample_grid(self, pointcloud, cam): 85 | # Project the pointcloud onto a projection 86 | # plane using the camera matrix 87 | grid = (cam @ pointcloud.unsqueeze(-1)).squeeze(-1) 88 | 89 | # Now grid has shape (n, h, w, 3). 90 | # Each pixel contains a 3-dimensional homogeneous coordinate. 91 | # To get to x,y coordinates the first two elements of 92 | # the homogeneous coordinates have to be divided by the third. 93 | grid = grid[:,:,:,:2] / (grid[:,:,:,2:3] + 1e-7) 94 | 95 | # At this point grid contains sampling coordinates in pixels 96 | # but grid_sample works with sampling coordinates in the range -1,1 97 | # so some rescaling has to be applied. 98 | h, w = grid.shape[1:3] 99 | dim = torch.tensor((w - 1, h - 1), dtype=grid.dtype, device=grid.device) 100 | grid = 2 * grid / dim - 1 101 | 102 | return grid 103 | 104 | def _shape_cam(self, inv_cam, cam): 105 | # We only need the top 3x4 of the camera matrix for projection 106 | # from pointcloud to homogeneous grid positions 107 | cam = cam[:,:3,:] 108 | 109 | # We only need the top 3x3 of the inverse matrix 110 | # for projection to pointcloud 111 | inv_cam = inv_cam[:,:3,:3] 112 | 113 | # Take the camera matrix from shape (N, 3, 4) to 114 | # (N, 1, 1, 3, 4) to match the pointcloud shape (N, H, W, 4). 115 | cam = cam.unsqueeze(1).unsqueeze(2) 116 | 117 | # Take the inverse camera matrix from shape (N, 3, 3) 118 | # to (N, 1, 1, 3, 3) to match the grid shape (N, H, W, 3). 119 | inv_cam = inv_cam.unsqueeze(1).unsqueeze(2) 120 | 121 | return inv_cam, cam 122 | 123 | def scale_disp(self, disp): 124 | return self.min_disp + (self.max_disp - self.min_disp) * disp 125 | 126 | def warp_images(self, inputs, outputs, outputs_masked): 127 | """Generate the warped (reprojected) color images for a minibatch. 128 | """ 129 | 130 | predictions = dict() 131 | resolutions = tuple(frozenset(k[1] for k in outputs if k[0] == 'disp')) 132 | frame_ids = tuple(frozenset(k[2] for k in outputs if k[0] == 'cam_T_cam')) 133 | 134 | inv_cam, cam = self._shape_cam(inputs["inv_K", 0], inputs["K", 0]) 135 | 136 | disps = tuple( 137 | functional.interpolate( 138 | outputs["disp", res], scale_factor=2**res, 139 | mode="bilinear", align_corners=False 140 | ) 141 | for res in resolutions 142 | ) 143 | 144 | depths = tuple( 145 | 1 / self.scale_disp(disp) 146 | for disp in disps 147 | ) 148 | 149 | # Take the pixel position in the target image and 150 | # the estimated depth to generate a point cloud of 151 | # target image pixels. 
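        # (High-level picture of the steps below: back-project every pixel to
        # a 3D point using its depth and the inverse intrinsics, move the
        # resulting point cloud into the other frame with cam_T_cam, project
        # it back with the intrinsics, and finally sample the source image at
        # the resulting grid positions with grid_sample.)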
152 | pointclouds_source = tuple( 153 | self._to_pointcloud(depth, inv_cam) 154 | for depth in depths 155 | ) 156 | 157 | # Calculate per-pixel surface normals 158 | surface_normals = tuple( 159 | self._surface_normal(pointcloud) 160 | for pointcloud in pointclouds_source 161 | ) 162 | 163 | for frame_id in frame_ids: 164 | img_source = inputs["color", frame_id, 0] 165 | cam_to_cam = outputs["cam_T_cam", 0, frame_id] 166 | 167 | # Transfer the estimated pointclouds from one frame-of 168 | # reference to the other 169 | pointclouds_target = tuple( 170 | self._change_perspective(pointcloud, cam_to_cam) 171 | for pointcloud in pointclouds_source 172 | ) 173 | 174 | # Using the projection matrix, map this point cloud 175 | # to expected pixel coordinates in the source image. 176 | grids = tuple( 177 | self._to_sample_grid(pointcloud, cam) 178 | for pointcloud in pointclouds_target 179 | ) 180 | 181 | # Construct warped target images by sampling from the source image 182 | for res, grid in zip(resolutions, grids): 183 | # TODO check which align corners behaviour is desired, default is false, but originally it used to 184 | # be true, also for segmentation warping, define grid_sample for 1.1.0 version and for newest version 185 | img_pred = torch.nn.functional.grid_sample(img_source, grid, padding_mode="border") 186 | predictions["sample", frame_id, res] = grid 187 | predictions["color", frame_id, res] = img_pred 188 | 189 | # sample the warped segmentation image 190 | if outputs_masked is not None: 191 | shape = outputs_masked[("segmentation", frame_id, 0)].shape 192 | seg_source = outputs_masked[("segmentation", frame_id, 0)].reshape( 193 | (shape[0], 1, shape[1], shape[2])) 194 | seg_pred = torch.nn.functional.grid_sample(seg_source.float(), grids[0], padding_mode="border", 195 | mode='nearest').reshape((shape[0], shape[1], shape[2])) 196 | outputs_masked["segmentation_warped", frame_id, 0] = seg_pred 197 | 198 | for res, depth, surface_normal in zip(resolutions, depths, surface_normals): 199 | predictions["depth", 0, res] = depth 200 | predictions["normals_pointcloud", res] = surface_normal 201 | 202 | return predictions 203 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.18.1 2 | pillow>=6.1.0 3 | torchvision>=0.3.0 4 | tensorboardx>=2.0 5 | matplotlib>=3.1.3 6 | pyyaml>=5.3 7 | cudatoolkit=10.0 8 | cudnn=7.5.1 9 | pytorch==1.1.0 -------------------------------------------------------------------------------- /state_manager.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.optim as optim 5 | 6 | import dataloader.file_io.get_path as get_path 7 | from models.sgdepth import SGDepth 8 | 9 | 10 | class ModelContext(object): 11 | def __init__(self, model, mode): 12 | self.model = model 13 | self.mode_wanted = mode 14 | 15 | def _set_mode(self, mode): 16 | if mode == 'train': 17 | self.model.train() 18 | elif mode == 'eval': 19 | self.model.eval() 20 | 21 | def __enter__(self): 22 | self.mode_was = 'train' if self.model.training else 'eval' 23 | self._set_mode(self.mode_wanted) 24 | 25 | return self.model 26 | 27 | def __exit__(self, *_): 28 | self._set_mode(self.mode_was) 29 | 30 | 31 | class ModelManager(object): 32 | def __init__(self, model): 33 | self.model = model 34 | 35 | def get_eval(self): 36 | return ModelContext(self.model, 'eval') 37 | 38 | def 
get_train(self): 39 | return ModelContext(self.model, 'train') 40 | 41 | 42 | class StateManager(object): 43 | def __init__(self, experiment_class, model_name, device, split_pos, num_layers, 44 | grad_scale_depth, grad_scale_seg, 45 | weights_init, resolutions_depth, num_layers_pose, 46 | learning_rate, weight_decay, scheduler_step_size): 47 | 48 | self.device = device 49 | 50 | path_getter = get_path.GetPath() 51 | self.log_base = path_getter.get_checkpoint_path() 52 | self.log_path = os.path.join(self.log_base, experiment_class, model_name) 53 | 54 | self._init_training() 55 | self._init_model( 56 | split_pos, num_layers, grad_scale_depth, grad_scale_seg, weights_init, resolutions_depth, 57 | num_layers_pose 58 | ) 59 | self._init_optimizer(learning_rate, weight_decay, scheduler_step_size) 60 | 61 | def _init_training(self): 62 | self.epoch = 0 63 | self.step = 0 64 | 65 | def _init_model(self, split_pos, num_layers, grad_scale_depth, grad_scale_seg, weights_init, resolutions_depth, 66 | num_layers_pose 67 | ): 68 | 69 | model = SGDepth(split_pos, num_layers, grad_scale_depth, grad_scale_seg, weights_init, 70 | resolutions_depth, num_layers_pose 71 | ) 72 | 73 | # noinspection PyUnresolvedReferences 74 | model = model.to(self.device) 75 | self.model_manager = ModelManager(model) 76 | 77 | def _init_optimizer(self, learning_rate, weight_decay, scheduler_step_size): 78 | with self.model_manager.get_train() as model: 79 | self.optimizer = optim.Adam(model.parameters(), learning_rate, weight_decay=weight_decay) 80 | 81 | self.lr_scheduler = optim.lr_scheduler.StepLR( 82 | self.optimizer, scheduler_step_size, learning_rate 83 | ) 84 | 85 | def _state_dir_paths(self, state_dir): 86 | return { 87 | 'model': os.path.join(self.log_base, state_dir, "model.pth"), 88 | 'optimizer': os.path.join(self.log_base, state_dir, "optim.pth"), 89 | 'scheduler': os.path.join(self.log_base, state_dir, "scheduler.pth"), 90 | 'train': os.path.join(self.log_base, state_dir, "train.pth"), 91 | } 92 | 93 | def store_state(self, state_dir): 94 | print(f"Storing model state to {state_dir}:") 95 | os.makedirs(state_dir, exist_ok=True) 96 | 97 | paths = self._state_dir_paths(state_dir) 98 | 99 | with self.model_manager.get_train() as model: 100 | torch.save(model.state_dict(), paths['model']) 101 | 102 | torch.save(self.optimizer.state_dict(), paths['optimizer']) 103 | torch.save(self.lr_scheduler.state_dict(), paths['scheduler']) 104 | 105 | state_train = { 106 | 'step': self.step, 107 | 'epoch': self.epoch, 108 | } 109 | 110 | torch.save(state_train, paths['train']) 111 | 112 | def store_checkpoint(self): 113 | state_dir = os.path.join(self.log_path, "checkpoints", f"epoch_{self.epoch}") 114 | self.store_state(state_dir) 115 | 116 | # Idea: log the model every batch to see how the training of the statistic parameters of the BN layers for the 117 | # shared encoder effect the validation 118 | def store_batch_checkpoint(self, batch_idx): 119 | state_dir = os.path.join(self.log_path, "checkpoints", f"batch_{batch_idx}") 120 | self.store_state(state_dir) 121 | 122 | def _load_model_state(self, path): 123 | with self.model_manager.get_train() as model: 124 | state = model.state_dict() 125 | to_load = torch.load(path, map_location=self.device) 126 | 127 | for (k, v) in to_load.items(): 128 | if k not in state: 129 | print(f" - WARNING: Model file contains unknown key {k} ({list(v.shape)})") 130 | 131 | for (k, v) in state.items(): 132 | if k not in to_load: 133 | print(f" - WARNING: Model file does not contain key {k} 
({list(v.shape)})") 134 | 135 | else: 136 | state[k] = to_load[k] 137 | 138 | model.load_state_dict(state) 139 | 140 | def _load_optimizer_state(self, path): 141 | state = torch.load(path, map_location=self.device) 142 | self.optimizer.load_state_dict(state) 143 | 144 | def _load_scheduler_state(self, path): 145 | state = torch.load(path) 146 | self.lr_scheduler.load_state_dict(state) 147 | 148 | def _load_training_state(self, path): 149 | state = torch.load(path) 150 | self.step = state['step'] 151 | self.epoch = state['epoch'] 152 | 153 | def load(self, state_dir, disable_lr_loading=False): 154 | """Load model(s) from a state directory on disk 155 | """ 156 | 157 | print(f"Loading checkpoint from {os.path.join(self.log_base, state_dir)}:") 158 | 159 | paths = self._state_dir_paths(state_dir) 160 | 161 | print(f" - Loading model state from {paths['model']}:") 162 | try: 163 | self._load_model_state(paths['model']) 164 | except FileNotFoundError: 165 | print(" - Could not find model state file") 166 | if not disable_lr_loading: 167 | print(f" - Loading optimizer state from {paths['optimizer']}:") 168 | try: 169 | self._load_optimizer_state(paths['optimizer']) 170 | except FileNotFoundError: 171 | print(" - Could not find optimizer state file") 172 | except ValueError: 173 | print(" - Optimizer state file is incompatible with current setup") 174 | 175 | print(f" - Loading scheduler state from {paths['scheduler']}:") 176 | try: 177 | self._load_scheduler_state(paths['scheduler']) 178 | except FileNotFoundError: 179 | print(" - Could not find scheduler state file") 180 | 181 | print(f" - Loading training state from {paths['train']}:") 182 | try: 183 | self._load_training_state(paths['train']) 184 | except FileNotFoundError: 185 | print(" - Could not find training state file") 186 | -------------------------------------------------------------------------------- /timer.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | class Timer(object): 4 | def __init__(self): 5 | self.time_prev = time.monotonic() 6 | self.category_prev = 'unaccounted' 7 | self.accumulators = dict() 8 | 9 | def enter(self, name): 10 | now = time.monotonic() 11 | delta = now - self.time_prev 12 | self.time_prev = now 13 | 14 | if self.category_prev not in self.accumulators: 15 | self.accumulators[self.category_prev] = 0 16 | 17 | self.accumulators[self.category_prev] += delta 18 | 19 | self.category_prev = name 20 | 21 | def leave(self): 22 | self.enter('unaccounted') 23 | 24 | def items(self): 25 | return self.accumulators.items() 26 | --------------------------------------------------------------------------------